diff --git a/cpp/src/arrow/acero/CMakeLists.txt b/cpp/src/arrow/acero/CMakeLists.txt index dc18afa9797..0b5029b7c79 100644 --- a/cpp/src/arrow/acero/CMakeLists.txt +++ b/cpp/src/arrow/acero/CMakeLists.txt @@ -33,6 +33,7 @@ set(ARROW_ACERO_REQUIRED_DEPENDENCIES Arrow ArrowCompute) set(ARROW_ACERO_SRCS accumulation_queue.cc + backpressure.cc scalar_aggregate_node.cc groupby_aggregate_node.cc aggregate_internal.cc @@ -60,6 +61,7 @@ set(ARROW_ACERO_SRCS time_series_util.cc tpch_node.cc union_node.cc + pipe_node.cc util.cc) append_runtime_avx2_src(ARROW_ACERO_SRCS bloom_filter_avx2.cc) @@ -173,6 +175,7 @@ function(add_arrow_acero_test REL_TEST_NAME) ${ARG_UNPARSED_ARGUMENTS}) endfunction() +add_arrow_acero_test(backpressure_test SOURCES backpressure_test.cc) add_arrow_acero_test(plan_test SOURCES plan_test.cc test_nodes_test.cc) add_arrow_acero_test(source_node_test SOURCES source_node_test.cc) add_arrow_acero_test(fetch_node_test SOURCES fetch_node_test.cc) @@ -183,6 +186,7 @@ add_arrow_acero_test(pivot_longer_node_test SOURCES pivot_longer_node_test.cc) add_arrow_acero_test(asof_join_node_test SOURCES asof_join_node_test.cc) add_arrow_acero_test(sorted_merge_node_test SOURCES sorted_merge_node_test.cc) +add_arrow_acero_test(pipe_node_test SOURCES pipe_node_test.cc) add_arrow_acero_test(tpch_node_test SOURCES tpch_node_test.cc) add_arrow_acero_test(union_node_test SOURCES union_node_test.cc) diff --git a/cpp/src/arrow/acero/asof_join_node.cc b/cpp/src/arrow/acero/asof_join_node.cc index 3970050e502..91920faa793 100644 --- a/cpp/src/arrow/acero/asof_join_node.cc +++ b/cpp/src/arrow/acero/asof_join_node.cc @@ -46,6 +46,7 @@ #ifndef NDEBUG # include "arrow/compute/function_internal.h" #endif +#include "arrow/acero/backpressure.h" #include "arrow/acero/time_series_util.h" #include "arrow/compute/key_hash_internal.h" #include "arrow/compute/light_array_internal.h" @@ -459,21 +460,6 @@ class KeyHasher { arrow::util::TempVectorStack stack_; }; -class BackpressureController : public BackpressureControl { - public: - BackpressureController(ExecNode* node, ExecNode* output, - std::atomic& backpressure_counter) - : node_(node), output_(output), backpressure_counter_(backpressure_counter) {} - - void Pause() override { node_->PauseProducing(output_, ++backpressure_counter_); } - void Resume() override { node_->ResumeProducing(output_, ++backpressure_counter_); } - - private: - ExecNode* node_; - ExecNode* output_; - std::atomic& backpressure_counter_; -}; - class InputState : public util::SerialSequencingQueue::Processor { // InputState corresponds to an input // Input record batches are queued up in InputState until processed and diff --git a/cpp/src/arrow/acero/backpressure.cc b/cpp/src/arrow/acero/backpressure.cc new file mode 100644 index 00000000000..70853e4c05b --- /dev/null +++ b/cpp/src/arrow/acero/backpressure.cc @@ -0,0 +1,108 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/acero/backpressure.h" +#include "arrow/acero/exec_plan.h" + +namespace arrow::acero { + +BackpressureController::BackpressureController(ExecNode* node, ExecNode* output, + std::atomic& backpressure_counter) + : node_(node), output_(output), backpressure_counter_(backpressure_counter) {} +void BackpressureController::Pause() { + node_->PauseProducing(output_, ++backpressure_counter_); +} +void BackpressureController::Resume() { + node_->ResumeProducing(output_, ++backpressure_counter_); +} + +BackpressureCombiner::BackpressureCombiner( + std::unique_ptr backpressure_control, bool pause_on_any) + : pause_on_any_(pause_on_any), + backpressure_control_(std::move(backpressure_control)) {} + +// Called from Source nodes +void BackpressureCombiner::Pause(Source* output) { + std::lock_guard lg(mutex_); + if (!paused_[output]) { + paused_[output] = true; + paused_count_++; + UpdatePauseStateUnlocked(); + } +} + +// Called from Source nodes +void BackpressureCombiner::Resume(Source* output) { + std::lock_guard lg(mutex_); + if (paused_.find(output) == paused_.end()) { + paused_[output] = false; + UpdatePauseStateUnlocked(); + } else if (paused_[output]) { + paused_[output] = false; + paused_count_--; + UpdatePauseStateUnlocked(); + } +} + +void BackpressureCombiner::Stop() { + std::lock_guard lg(mutex_); + stopped = true; + backpressure_control_->Resume(); + paused = false; +} + +void BackpressureCombiner::UpdatePauseStateUnlocked() { + if (stopped) return; + bool should_be_paused = (paused_count_ > 0); + if (!pause_on_any_) { + should_be_paused = should_be_paused && (paused_count_ == paused_.size()); + } + if (should_be_paused) { + if (!paused) { + backpressure_control_->Pause(); + paused = true; + } + } else { + if (paused) { + backpressure_control_->Resume(); + paused = false; + } + } +} + +BackpressureCombiner::Source::Source(BackpressureCombiner* ctrl) { + if (ctrl) { + AddController(ctrl); + } +} + +void BackpressureCombiner::Source::AddController(BackpressureCombiner* ctrl) { + ctrl->Resume(this); // populate map in controller + connections_.push_back(ctrl); +} +void BackpressureCombiner::Source::Pause() { + for (auto& conn_ : connections_) { + conn_->Pause(this); + } +} +void BackpressureCombiner::Source::Resume() { + for (auto& conn_ : connections_) { + conn_->Resume(this); + } +} + +} // namespace arrow::acero diff --git a/cpp/src/arrow/acero/backpressure.h b/cpp/src/arrow/acero/backpressure.h new file mode 100644 index 00000000000..303ec551070 --- /dev/null +++ b/cpp/src/arrow/acero/backpressure.h @@ -0,0 +1,99 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/acero/options.h" + +namespace arrow::acero { + +// Generic backpressure controller for ExecNode +class ARROW_ACERO_EXPORT BackpressureController : public BackpressureControl { + public: + BackpressureController(ExecNode* node, ExecNode* output, + std::atomic& backpressure_counter); + + void Pause() override; + void Resume() override; + + private: + ExecNode* node_; + ExecNode* output_; + std::atomic& backpressure_counter_; +}; + +template +class BackpressureControlWrapper : public BackpressureControl { + public: + explicit BackpressureControlWrapper(T* obj) : obj_(obj) {} + + void Pause() override { obj_->Pause(); } + void Resume() override { obj_->Resume(); } + + private: + T* obj_; +}; + +// Provides infrastructure of combining multiple backpressure sources and propagate the +// result into BackpressureControl There are two logic scheme of backpressure: +// 1. Default pause_on_any=true - pause on any source is propagated - OR logic +// 2. pause_on_any=false - pause is propagated only when all sources are paused - AND +// logic +class ARROW_ACERO_EXPORT BackpressureCombiner { + public: + explicit BackpressureCombiner(std::unique_ptr backpressure_control, + bool pause_on_any = true); + + // Instances of Source can be used as usual BackpresureControl. + // That means that BackpressureCombiner::Source can use another BackpressureCombiner + // as backpressure_control + // This enabled building more complex backpressure logic using AND/OR operations. + // Source can also be connected with more BackpressureCombiners to facilitate + // propagation of backpressure to multiple inputs. + class ARROW_ACERO_EXPORT Source : public BackpressureControl { + public: + // strong - strong_connection=true + // weak - strong_connection=false + explicit Source(BackpressureCombiner* ctrl = nullptr); + void AddController(BackpressureCombiner* ctrl); + void Pause() override; + void Resume() override; + + private: + std::vector connections_; + }; + + void Stop(); + + private: + friend class Source; + void Pause(Source* output); + void Resume(Source* output); + + void UpdatePauseStateUnlocked(); + bool pause_on_any_; + std::unique_ptr backpressure_control_; + std::mutex mutex_; + std::unordered_map paused_; + size_t paused_count_{0}; + bool paused{false}; + bool stopped{false}; +}; + +} // namespace arrow::acero diff --git a/cpp/src/arrow/acero/backpressure_test.cc b/cpp/src/arrow/acero/backpressure_test.cc new file mode 100644 index 00000000000..596a5bec01f --- /dev/null +++ b/cpp/src/arrow/acero/backpressure_test.cc @@ -0,0 +1,128 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "arrow/acero/backpressure.h" + +namespace arrow { +namespace acero { + +class MonitorBackpressureControl : public acero::BackpressureControl { + public: + explicit MonitorBackpressureControl(std::atomic& paused) : paused(paused) {} + virtual void Pause() { paused = true; } + virtual void Resume() { paused = false; } + std::atomic& paused; +}; + +TEST(BackpressureCombiner, Basic) { + std::atomic paused{false}; + BackpressureCombiner or_combiner(std::make_unique(paused)); + BackpressureCombiner::Source strong_source1(&or_combiner); + BackpressureCombiner::Source strong_source2; + strong_source2.AddController(&or_combiner); + + BackpressureCombiner and_combiner( + std::make_unique(&or_combiner), + /*pause_on_any=*/false); + BackpressureCombiner::Source weak_source1(&and_combiner); + BackpressureCombiner::Source weak_source2; + weak_source2.AddController(&and_combiner); + + // Any strong causes pause + ASSERT_FALSE(paused); + strong_source1.Pause(); + ASSERT_TRUE(paused); + strong_source2.Pause(); + ASSERT_TRUE(paused); + strong_source1.Resume(); + ASSERT_TRUE(paused); + strong_source2.Resume(); + ASSERT_FALSE(paused); + + // All weak cause pause + ASSERT_FALSE(paused); + weak_source1.Pause(); + ASSERT_FALSE(paused); + weak_source2.Pause(); + ASSERT_TRUE(paused); + weak_source1.Resume(); + ASSERT_FALSE(paused); + weak_source2.Resume(); + ASSERT_FALSE(paused); + + // mixed use + strong_source1.Pause(); + ASSERT_TRUE(paused); + + ASSERT_TRUE(paused); + weak_source1.Pause(); + ASSERT_TRUE(paused); + weak_source2.Pause(); + + strong_source1.Resume(); + ASSERT_TRUE(paused); + + weak_source1.Resume(); + ASSERT_FALSE(paused); + weak_source2.Resume(); + ASSERT_FALSE(paused); +} + +TEST(BackpressureCombiner, OnlyStrong) { + std::atomic paused{false}; + BackpressureCombiner combiner(std::make_unique(paused)); + BackpressureCombiner::Source strong_source1(&combiner); + BackpressureCombiner::Source strong_source2; + strong_source2.AddController(&combiner); + + // Any strong causes pause + ASSERT_FALSE(paused); + strong_source1.Pause(); + ASSERT_TRUE(paused); + strong_source2.Pause(); + ASSERT_TRUE(paused); + strong_source1.Resume(); + ASSERT_TRUE(paused); + strong_source2.Resume(); + ASSERT_FALSE(paused); +} + +TEST(BackpressureCombiner, OnlyWeak) { + std::atomic paused{false}; + BackpressureCombiner combiner(std::make_unique(paused), + false); + + BackpressureCombiner::Source weak_source1(&combiner); + BackpressureCombiner::Source weak_source2; + weak_source2.AddController(&combiner); + + // All weak cause pause + ASSERT_FALSE(paused); + weak_source1.Pause(); + ASSERT_FALSE(paused); + weak_source2.Pause(); + ASSERT_TRUE(paused); + weak_source1.Resume(); + ASSERT_FALSE(paused); + weak_source2.Resume(); + ASSERT_FALSE(paused); +} + +} // namespace acero +} // namespace arrow diff --git a/cpp/src/arrow/acero/exec_plan.cc b/cpp/src/arrow/acero/exec_plan.cc index ff5e5d8bdd9..d006101242c 100644 --- a/cpp/src/arrow/acero/exec_plan.cc +++ b/cpp/src/arrow/acero/exec_plan.cc @@ -1102,6 +1102,24 @@ Result> DeclarationToReader( return 
DeclarationToReader(std::move(declaration), std::move(options)); } +namespace internal { + +void RegisterSourceNode(ExecFactoryRegistry*); +void RegisterFetchNode(ExecFactoryRegistry*); +void RegisterFilterNode(ExecFactoryRegistry*); +void RegisterOrderByNode(ExecFactoryRegistry*); +void RegisterPivotLongerNode(ExecFactoryRegistry*); +void RegisterProjectNode(ExecFactoryRegistry*); +void RegisterUnionNode(ExecFactoryRegistry*); +void RegisterAggregateNode(ExecFactoryRegistry*); +void RegisterSinkNode(ExecFactoryRegistry*); +void RegisterHashJoinNode(ExecFactoryRegistry*); +void RegisterAsofJoinNode(ExecFactoryRegistry*); +void RegisterSortedMergeNode(ExecFactoryRegistry*); +void RegisterPipeNodes(ExecFactoryRegistry*); + +} // namespace internal + ExecFactoryRegistry* default_exec_factory_registry() { class DefaultRegistry : public ExecFactoryRegistry { public: @@ -1118,6 +1136,7 @@ ExecFactoryRegistry* default_exec_factory_registry() { internal::RegisterHashJoinNode(this); internal::RegisterAsofJoinNode(this); internal::RegisterSortedMergeNode(this); + internal::RegisterPipeNodes(this); } Result GetFactory(const std::string& factory_name) override { diff --git a/cpp/src/arrow/acero/exec_plan_internal.h b/cpp/src/arrow/acero/exec_plan_internal.h index e9fe87b69e9..ab5e1fd6760 100644 --- a/cpp/src/arrow/acero/exec_plan_internal.h +++ b/cpp/src/arrow/acero/exec_plan_internal.h @@ -33,5 +33,6 @@ void RegisterSinkNode(ExecFactoryRegistry*); void RegisterHashJoinNode(ExecFactoryRegistry*); void RegisterAsofJoinNode(ExecFactoryRegistry*); void RegisterSortedMergeNode(ExecFactoryRegistry*); +void RegisterPipeNodes(ExecFactoryRegistry*); } // namespace arrow::acero::internal diff --git a/cpp/src/arrow/acero/meson.build b/cpp/src/arrow/acero/meson.build index c7a8bdb4ca6..64e25f4dfe1 100644 --- a/cpp/src/arrow/acero/meson.build +++ b/cpp/src/arrow/acero/meson.build @@ -51,6 +51,7 @@ arrow_acero_srcs = [ 'groupby_aggregate_node.cc', 'aggregate_internal.cc', 'asof_join_node.cc', + 'backpressure.cc', 'bloom_filter.cc', 'exec_plan.cc', 'fetch_node.cc', @@ -74,6 +75,7 @@ arrow_acero_srcs = [ 'time_series_util.cc', 'tpch_node.cc', 'union_node.cc', + 'pipe_node.cc', 'util.cc', ] @@ -93,6 +95,7 @@ meson.override_dependency('arrow-acero', arrow_acero_dep) arrow_acero_testing_sources = ['test_nodes.cc', 'test_util_internal.cc'] arrow_acero_tests = { + 'backpressure-test': {'sources': ['backpressure_test.cc']}, 'plan-test': {'sources': ['plan_test.cc', 'test_nodes_test.cc']}, 'source-node-test': {'sources': ['source_node_test.cc']}, 'fetch-node-test': {'sources': ['fetch_node_test.cc']}, diff --git a/cpp/src/arrow/acero/options.h b/cpp/src/arrow/acero/options.h index 827e9ea775d..810ff679533 100644 --- a/cpp/src/arrow/acero/options.h +++ b/cpp/src/arrow/acero/options.h @@ -23,6 +23,7 @@ #include #include +#include "arrow/acero/exec_plan.h" #include "arrow/acero/type_fwd.h" #include "arrow/acero/visibility.h" #include "arrow/compute/api_aggregate.h" @@ -868,6 +869,52 @@ class ARROW_ACERO_EXPORT PivotLongerNodeOptions : public ExecNodeOptions { std::vector measurement_field_names; }; +/// \brief a node which implements experimental node that enables multiple sink exec nodes +/// +/// Note, this API is experimental and will change in the future +/// +/// This node forwards each exec batch to its output and also provides number of +/// additional source nodes for additional acero pipelines. 
+class ARROW_ACERO_EXPORT PipeSourceNodeOptions : public ExecNodeOptions { + public: + PipeSourceNodeOptions(std::string pipe_name, std::shared_ptr output_schema, + Ordering ordering = Ordering::Unordered()) + : pipe_name(std::move(pipe_name)), + output_schema(std::move(output_schema)), + ordering(std::move(ordering)) {} + + /// \brief Pipe name used to match with pipe sink + std::string pipe_name; + + /// \brief Expected schema of data. Validated during initialization. + std::shared_ptr output_schema; + + /// \brief Expected ordering of data. Validated during initialization. + Ordering ordering; +}; + +class ARROW_ACERO_EXPORT PipeSinkNodeOptions : public ExecNodeOptions { + public: + PipeSinkNodeOptions(std::string pipe_name, bool pause_on_any = true, + bool stop_on_any = false) + : pipe_name(std::move(pipe_name)), + pause_on_any(pause_on_any), + stop_on_any(stop_on_any) {} + + /// \brief Pipe name used to match with pipe sources + std::string pipe_name; + + /// \brief pause_on_any controls pausing strategy. If true sink input will be paused + /// when any source is paused. If false sink input will be paused hen all sources are + /// paused + bool pause_on_any; + + /// \brief stop_on_any controls stopping strategy. If true sink input will be stopped + /// when any source is stopped. If false sink input will be stopped hen all sources are + /// stopped + bool stop_on_any; +}; + /// @} } // namespace acero diff --git a/cpp/src/arrow/acero/pipe_node.cc b/cpp/src/arrow/acero/pipe_node.cc new file mode 100644 index 00000000000..e726ddc6935 --- /dev/null +++ b/cpp/src/arrow/acero/pipe_node.cc @@ -0,0 +1,394 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/acero/pipe_node.h" +#include +#include "arrow/acero/exec_plan.h" +#include "arrow/acero/exec_plan_internal.h" +#include "arrow/acero/map_node.h" +#include "arrow/acero/options.h" +#include "arrow/acero/query_context.h" +#include "arrow/compute/api_vector.h" +#include "arrow/compute/exec.h" +#include "arrow/compute/expression.h" +#include "arrow/datum.h" +#include "arrow/result.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/future.h" +#include "arrow/util/logging.h" +#include "arrow/util/logging_internal.h" +#include "arrow/util/tracing_internal.h" +#include "arrow/util/unreachable.h" +namespace arrow { + +using internal::checked_cast; + +namespace acero { + +namespace { + +struct PipeSourceNode : public PipeSource, public ExecNode { + PipeSourceNode(ExecPlan* plan, std::shared_ptr schema, std::string pipe_name, + Ordering ordering) + : ExecNode(plan, {}, {}, std::move(schema)), + pipe_name_(pipe_name), + ordering_(std::move(ordering)) {} + + static Result Make(ExecPlan* plan, std::vector inputs, + const ExecNodeOptions& options) { + RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 0, kKindName)); + const auto& pipe_source_options = checked_cast(options); + return plan->EmplaceNode(plan, pipe_source_options.output_schema, + pipe_source_options.pipe_name, + pipe_source_options.ordering); + } + + [[noreturn]] static void NoInputs() { + Unreachable("no inputs; this should never be called"); + } + [[noreturn]] Status InputReceived(ExecNode*, ExecBatch) override { NoInputs(); } + [[noreturn]] Status InputFinished(ExecNode*, int) override { NoInputs(); } + + Status HandleInputReceived(ExecBatch batch) override { + return output_->InputReceived(this, std::move(batch)); + } + Status HandleInputFinished(int total_batches) override { + return output_->InputFinished(this, total_batches); + } + + const Ordering& ordering() const override { return ordering_; } + + Status StartProducing() override { + auto st = PipeSource::Validate(ordering()); + if (!st.ok()) { + return st.WithMessage("Pipe '", pipe_name_, "' error: ", st.message()); + } + return Status::OK(); + } + + void PauseProducing(ExecNode* output, int32_t counter) override { + PipeSource::Pause(counter); + } + void ResumeProducing(ExecNode* output, int32_t counter) override { + PipeSource::Resume(counter); + } + + Status StopProducingImpl() override { return PipeSource::StopProducing(); } + + static const char kKindName[]; + const char* kind_name() const override { return kKindName; } + + const std::string pipe_name_; + Ordering ordering_; +}; + +const char PipeSourceNode::kKindName[] = "PipeSourceNode"; + +class PipeSinkBackpressureControl : public BackpressureControl { + public: + PipeSinkBackpressureControl(ExecNode* node, ExecNode* output) + : node_(node), output_(output) {} + + void Pause() override { node_->PauseProducing(output_, ++backpressure_counter_); } + void Resume() override { node_->ResumeProducing(output_, ++backpressure_counter_); } + + private: + ExecNode* node_; + ExecNode* output_; + std::atomic backpressure_counter_; +}; + +class PipeSinkNode : public ExecNode { + public: + PipeSinkNode(ExecPlan* plan, std::vector inputs, std::string pipe_name, + bool pause_on_any, bool stop_on_any) + : ExecNode(plan, inputs, /*input_labels=*/{pipe_name}, {}) { + pipe_ = std::make_shared( + plan, std::move(pipe_name), + std::make_unique(inputs[0], this), + [this]() { return StopProducing(); }, inputs[0]->ordering(), pause_on_any, + stop_on_any); + } + + static Result Make(ExecPlan* plan, std::vector 
inputs, + const ExecNodeOptions& options) { + RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "PipeSinkNode")); + const auto& pipe_tee_options = checked_cast(options); + return plan->EmplaceNode( + plan, std::move(inputs), pipe_tee_options.pipe_name, + pipe_tee_options.pause_on_any, pipe_tee_options.stop_on_any); + } + static const char kKindName[]; + const char* kind_name() const override { return kKindName; } + + Status InputReceived(ExecNode* input, ExecBatch batch) override { + DCHECK_EQ(input, inputs_[0]); + return pipe_->InputReceived(batch); + } + + Status InputFinished(ExecNode* input, int total_batches) override { + DCHECK_EQ(input, inputs_[0]); + return pipe_->InputFinished(total_batches); + } + + Status Init() override { + ARROW_RETURN_NOT_OK(pipe_->Init(inputs_[0]->output_schema(), true)); + return ExecNode::Init(); + } + + Status StartProducing() override { return Status::OK(); } + + // sink nodes have no outputs from which to feel backpressure + [[noreturn]] static void NoOutputs() { + Unreachable("no outputs; this should never be called"); + } + [[noreturn]] void ResumeProducing(ExecNode* output, int32_t counter) override { + NoOutputs(); + } + [[noreturn]] void PauseProducing(ExecNode* output, int32_t counter) override { + NoOutputs(); + } + + protected: + Status StopProducingImpl() override { return Status::OK(); } + + std::string ToStringExtra(int indent = 0) const override { return "pipe_sink"; } + + protected: + std::shared_ptr pipe_; +}; + +const char PipeSinkNode::kKindName[] = "PipeSinkNode"; + +class PipeTeeNode : public PipeSource, public PipeSinkNode { + public: + PipeTeeNode(ExecPlan* plan, std::vector inputs, std::string pipe_name, + bool pause_on_any, bool stop_on_any) + : PipeSinkNode(plan, inputs, pipe_name, pause_on_any, stop_on_any) { + output_schema_ = inputs[0]->output_schema(); + auto st = PipeSinkNode::pipe_->addSyncSource(this); + if (ARROW_PREDICT_FALSE(!st.ok())) { // this should never happen + Unreachable(std::string("PipeTee unexpected error: ") + st.ToString()); + } + } + + static Result Make(ExecPlan* plan, std::vector inputs, + const ExecNodeOptions& options) { + RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "PipeTeeNode")); + const auto& pipe_tee_options = checked_cast(options); + return plan->EmplaceNode( + plan, std::move(inputs), pipe_tee_options.pipe_name, + pipe_tee_options.pause_on_any, pipe_tee_options.stop_on_any); + } + static const char kKindName[]; + const char* kind_name() const override { return kKindName; } + + const Ordering& ordering() const override { return inputs_[0]->ordering(); } + + Status StartProducing() override { + auto st = PipeSource::Validate(ordering()); + if (!st.ok()) { + return st.WithMessage("Pipe '", PipeSinkNode::pipe_->PipeName(), + "' error: ", st.message()); + } + return PipeSinkNode::StartProducing(); + } + + void PauseProducing(ExecNode* output, int32_t counter) override { + PipeSource::Pause(counter); + } + + void ResumeProducing(ExecNode* output, int32_t counter) override { + PipeSource::Resume(counter); + } + + Status HandleInputReceived(ExecBatch batch) override { + return output_->InputReceived(this, std::move(batch)); + } + Status HandleInputFinished(int total_batches) override { + return output_->InputFinished(this, total_batches); + } + + protected: + Status StopProducingImpl() override { return Status::OK(); } + + std::string ToStringExtra(int indent = 0) const override { return "pipe_tee"; } +}; + +const char PipeTeeNode::kKindName[] = "PipeTeeNode"; + +} // namespace + 
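
The three node classes above are only reachable through the factory names registered at the bottom of this file ("pipe_source", "pipe_sink", "pipe_tee"), and a sink/tee is matched to its sources purely by pipe name. Below is a minimal sketch of that wiring, assuming the options added to options.h in this patch; the function, namespace, and pipe name are hypothetical and the snippet is not part of the patch.

```cpp
#include <memory>
#include <optional>

#include "arrow/acero/exec_plan.h"
#include "arrow/acero/options.h"
#include "arrow/result.h"
#include "arrow/util/async_generator.h"

namespace sketch {

using namespace arrow;         // Schema, ExecBatch, Result, AsyncGenerator
using namespace arrow::acero;  // Declaration, ExecPlan, node options

// Builds (but does not start) a plan where the main branch tees every batch
// into "aux_pipe" and a second branch consumes it through a pipe_source node.
Result<std::shared_ptr<ExecPlan>> BuildTeePlan(
    Declaration main_input, std::shared_ptr<Schema> tee_schema,
    AsyncGenerator<std::optional<ExecBatch>>* main_gen,
    AsyncGenerator<std::optional<ExecBatch>>* aux_gen) {
  ARROW_ASSIGN_OR_RAISE(auto plan, ExecPlan::Make());

  // Main branch: input -> pipe_tee("aux_pipe") -> sink.
  // The tee forwards batches downstream *and* into the named pipe.
  Declaration main = Declaration::Sequence(
      {std::move(main_input),
       {"pipe_tee", PipeSinkNodeOptions{"aux_pipe"}},
       {"sink", SinkNodeOptions{main_gen}}});

  // Side branch: pipe_source("aux_pipe") -> sink, matched with the tee by name.
  Declaration aux = Declaration::Sequence(
      {{"pipe_source", PipeSourceNodeOptions{"aux_pipe", tee_schema}},
       {"sink", SinkNodeOptions{aux_gen}}});

  ARROW_RETURN_NOT_OK(main.AddToPlan(plan.get()));
  ARROW_RETURN_NOT_OK(aux.AddToPlan(plan.get()));
  ARROW_RETURN_NOT_OK(plan->Validate());
  return plan;  // the caller starts the plan and drains both sink generators
}

}  // namespace sketch
```
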
+PipeSource::PipeSource() {} +Status PipeSource::Initialize(Pipe* pipe) { + if (pipe_) return Status::Invalid("Pipe:" + pipe->PipeName() + " has multiple sinks"); + pipe_ = pipe; + backpressure_source_.AddController(pipe); + return Status::OK(); +} + +void PipeSource::Pause(int32_t counter) { + auto lock = mutex_.Lock(); + if (!stopped) { + if (backpressure_counter < counter) { + backpressure_counter = counter; + backpressure_source_.Pause(); + } + } +} +void PipeSource::Resume(int32_t counter) { + auto lock = mutex_.Lock(); + if (backpressure_counter < counter) { + backpressure_counter = counter; + backpressure_source_.Resume(); + } +} +Status PipeSource::StopProducing() { + if (pipe_) { + { + auto lock = mutex_.Lock(); + stopped = true; + backpressure_source_.Resume(); + } + return pipe_->StopProducing(this); + } + return Status::OK(); +} + +Status PipeSource::Validate(const Ordering& ordering) { + if (!pipe_) { + return Status::Invalid("Pipe does not have sink"); + } + if (!ordering.IsSuborderOf(pipe_->ordering())) + if (!(ordering.is_implicit() && !pipe_->ordering().is_unordered())) + return Status::Invalid( + "Pipe source ordering: ", ordering.ToString(), + " is not suborder of pipe sink: ", pipe_->ordering().ToString()); + + return Status::OK(); +} + +Pipe::Pipe(ExecPlan* plan, std::string pipe_name, + std::unique_ptr ctrl, + std::function stopProducing, Ordering ordering, bool pause_on_any, + bool stop_on_any) + : BackpressureCombiner(std::move(ctrl), pause_on_any), + plan_(plan), + ordering_(ordering), + pipe_name_(pipe_name), + stopProducing_(stopProducing), + stop_on_any_(stop_on_any) {} + +const Ordering& Pipe::ordering() const { return ordering_; } + +Status Pipe::StopProducing(PipeSource* output) { + auto lock = mutex_.Lock(); + auto& state = state_[output]; + DCHECK(!state.stopped); + state.stopped = true; + size_t stopped_count = ++stopped_count_; + if (stop_on_any_) { + if (stopped_count == 1) { + return stopProducing_(); + } + } else { + if (stopped_count == CountSources()) { + return stopProducing_(); + } + } + return Status::OK(); +} + +Status Pipe::InputReceived(ExecBatch batch) { + for (auto& source_node : async_nodes_) { + { + auto lock = mutex_.Lock(); + if (state_[source_node].stopped) continue; + } + plan_->query_context()->ScheduleTask( + [source_node, batch]() mutable { + return source_node->HandleInputReceived(batch); + }, + "Pipe::InputReceived"); + } + if (sync_node_) return sync_node_->HandleInputReceived(batch); + + return Status::OK(); +} + +Status Pipe::InputFinished(int total_batches) { + for (auto& source_node : async_nodes_) { + plan_->query_context()->ScheduleTask( + [source_node, total_batches]() { + return source_node->HandleInputFinished(total_batches); + }, + "Pipe::HandleInputFinished"); + } + if (sync_node_) return sync_node_->HandleInputFinished(total_batches); + return Status::OK(); +} + +Status Pipe::addAsyncSource(PipeSource* source, bool may_be_sync) { + ARROW_RETURN_NOT_OK(source->Initialize(this)); + if (may_be_sync && !sync_node_) + sync_node_ = source; + else + async_nodes_.push_back(source); + return Status::OK(); +} +Status Pipe::addSyncSource(PipeSource* source) { + ARROW_RETURN_NOT_OK(source->Initialize(this)); + if (sync_node_) return Status::Invalid("Pipe last source overwrite for " + pipe_name_); + sync_node_ = source; + return Status::OK(); +} + +Status Pipe::Init(const std::shared_ptr schema, bool may_be_sync) { + for (auto node : plan_->nodes()) { + if (node->kind_name() == PipeSourceNode::kKindName) { + PipeSourceNode* 
pipe_source = checked_cast(node); + if (pipe_source->pipe_name_ == pipe_name_) { + if (!schema->Equals(node->output_schema())) { + return Status::Invalid("Pipe schema does not match for " + pipe_name_); + } + ARROW_RETURN_NOT_OK(addAsyncSource(pipe_source, may_be_sync)); + } + } + } + return Status::OK(); +} + +bool Pipe::HasSources() const { return sync_node_ || !async_nodes_.empty(); } + +size_t Pipe::CountSources() const { + auto count = async_nodes_.size(); + if (sync_node_) { + count++; + } + return count; +} + +namespace internal { +void RegisterPipeNodes(ExecFactoryRegistry* registry) { + DCHECK_OK(registry->AddFactory("pipe_source", PipeSourceNode::Make)); + DCHECK_OK(registry->AddFactory("pipe_sink", PipeSinkNode::Make)); + DCHECK_OK(registry->AddFactory("pipe_tee", PipeTeeNode::Make)); +} + +} // namespace internal +} // namespace acero +} // namespace arrow diff --git a/cpp/src/arrow/acero/pipe_node.h b/cpp/src/arrow/acero/pipe_node.h new file mode 100644 index 00000000000..2e31826b3e7 --- /dev/null +++ b/cpp/src/arrow/acero/pipe_node.h @@ -0,0 +1,131 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once +#include +#include "arrow/acero/backpressure.h" +#include "arrow/acero/exec_plan.h" +#include "arrow/acero/options.h" +#include "arrow/acero/visibility.h" +#include "arrow/util/mutex.h" + +namespace arrow { + +using arrow::util::Mutex; +using compute::Ordering; +using internal::checked_cast; + +namespace acero { +class Pipe; +class PipeSource { + public: + PipeSource(); + virtual ~PipeSource() {} + void Pause(int32_t counter); + void Resume(int32_t counter); + Status StopProducing(); + Status Validate(const Ordering& ordering); + + private: + friend class Pipe; + Status Initialize(Pipe* pipe); + + virtual Status HandleInputReceived(ExecBatch batch) = 0; + virtual Status HandleInputFinished(int total_batches) = 0; + + Pipe* pipe_{nullptr}; + BackpressureCombiner::Source backpressure_source_; + Mutex mutex_; + int backpressure_counter{0}; + std::atomic stopped{false}; +}; + +/// @brief Provides pipe like infastructure for Acero. It isa center element for +/// pipe_sink/pipe_tee and pipe_source infrastructure. Can also be used to create +/// auxiliarty outputs(pipes) for ExecNodes. +class ARROW_ACERO_EXPORT Pipe : public BackpressureCombiner { + public: + Pipe(ExecPlan* plan, std::string pipe_name, std::unique_ptr ctrl, + std::function stopProducing, Ordering ordering = Ordering::Unordered(), + bool pause_on_any = true, bool stop_on_any = false); + + const Ordering& ordering() const; + + // Scan current exec plan to find all pipe_source nodes matching pipe name. + // All matching pipe_sources will be added to async sources list. + // If may_be_sync is true first added source will be added to as sync source. 
+  // This method should be called from ExecNode::Init() of the node containing this pipe
+  Status Init(const std::shared_ptr<Schema> schema, bool may_be_sync = false);
+
+  /// @brief Adds an asynchronous PipeSource
+  /// @param source - pipe source to be added
+  /// Batches will be delivered to each asynchronous source in a separate task
+  Status addAsyncSource(PipeSource* source, bool may_be_sync = false);
+
+  /// @brief Adds a synchronous PipeSource
+  /// @param source - pipe source to be added
+  /// Batches will be delivered synchronously in InputReceived. Only one synchronous
+  /// source is supported - this should be the last operation within the task.
+  Status addSyncSource(PipeSource* source);
+
+  /// @brief Transfer a batch to the pipe
+  Status InputReceived(ExecBatch batch);
+  /// @brief Mark the inputs finished after the given number of batches.
+  Status InputFinished(int total_batches);
+
+  /// @brief Check whether any pipe_sources are connected
+  /// Can be used to skip production of exec batches when there are no consumers.
+  bool HasSources() const;
+
+  /// @brief Return the number of connected pipe_sources
+  size_t CountSources() const;
+
+  /// @brief Get pipe_name
+  std::string PipeName() const { return pipe_name_; }
+
+ private:
+  friend class PipeSource;
+
+  struct SourceState {
+    bool stopped{false};
+  };
+
+  // Backpressure interface for PipeSource
+  void Pause(PipeSource* output, int counter);
+  // Backpressure interface for PipeSource
+  void Resume(PipeSource* output, int counter);
+  //
+  Status StopProducing(PipeSource* output);
+
+ private:
+  ExecPlan* plan_;
+  Ordering ordering_;
+  std::string pipe_name_;
+  std::vector<PipeSource*> async_nodes_;
+  PipeSource* sync_node_{NULLPTR};
+  // backpressure
+  std::unordered_map<PipeSource*, SourceState> state_;
+  Mutex mutex_;
+  // stopProducing
+  std::atomic_size_t stopped_count_{0};
+  std::function<Status()> stopProducing_;
+
+  const bool stop_on_any_;
+};
+
+}  // namespace acero
+}  // namespace arrow
diff --git a/cpp/src/arrow/acero/pipe_node_test.cc b/cpp/src/arrow/acero/pipe_node_test.cc
new file mode 100644
index 00000000000..07b4325302e
--- /dev/null
+++ b/cpp/src/arrow/acero/pipe_node_test.cc
@@ -0,0 +1,386 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
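
The Pipe class declared above is also meant to be usable outside the pipe_sink/pipe_tee nodes, as an auxiliary output of an arbitrary ExecNode. Below is a fragment-level sketch of that usage, relying only on the Pipe and BackpressureController API introduced in this patch; the node class is hypothetical, and its constructor plus the remaining ExecNode overrides are omitted.

```cpp
#include <atomic>
#include <memory>

#include "arrow/acero/backpressure.h"
#include "arrow/acero/exec_plan.h"
#include "arrow/acero/pipe_node.h"

namespace sketch {

using namespace arrow;
using namespace arrow::acero;

// Hypothetical pass-through node that additionally tees its input into "aux_pipe".
class NodeWithAuxPipe : public ExecNode {
 public:
  // ... constructor and the other ExecNode overrides omitted ...

  Status Init() override {
    pipe_ = std::make_shared<Pipe>(
        plan_, "aux_pipe",
        std::make_unique<BackpressureController>(inputs_[0], this,
                                                 backpressure_counter_),
        /*stopProducing=*/[this] { return StopProducing(); },
        inputs_[0]->ordering());
    // Connects every pipe_source node in the plan declared with the name "aux_pipe".
    ARROW_RETURN_NOT_OK(pipe_->Init(inputs_[0]->output_schema()));
    return ExecNode::Init();
  }

  Status InputReceived(ExecNode* input, ExecBatch batch) override {
    if (pipe_->HasSources()) {
      // Forward a copy of the batch into the auxiliary pipe.
      ARROW_RETURN_NOT_OK(pipe_->InputReceived(batch));
    }
    return output_->InputReceived(this, std::move(batch));
  }

  Status InputFinished(ExecNode* input, int total_batches) override {
    ARROW_RETURN_NOT_OK(pipe_->InputFinished(total_batches));
    return output_->InputFinished(this, total_batches);
  }

 private:
  std::shared_ptr<Pipe> pipe_;
  std::atomic<int32_t> backpressure_counter_{0};
};

}  // namespace sketch
```
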
+ +#include + +#include +#include + +#include "arrow/acero/exec_plan.h" +#include "arrow/acero/options.h" +#include "arrow/acero/test_nodes.h" +#include "arrow/acero/test_util_internal.h" +#include "arrow/acero/util.h" +#include "arrow/compute/exec.h" +#include "arrow/compute/expression.h" +#include "arrow/compute/test_util_internal.h" +#include "arrow/io/util_internal.h" +#include "arrow/record_batch.h" +#include "arrow/table.h" +#include "arrow/testing/future_util.h" +#include "arrow/testing/generator.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/matchers.h" +#include "arrow/testing/random.h" +#include "arrow/util/async_generator.h" +#include "arrow/util/config.h" +#include "arrow/util/logging.h" +#include "arrow/util/macros.h" +#include "arrow/util/thread_pool.h" +#include "arrow/util/vector.h" + +using testing::HasSubstr; + +namespace arrow { + +using compute::call; +using compute::ExecBatchFromJSON; +using compute::field_ref; + +namespace acero { + +void CheckFinishesCancelledOrOk(const Future<>& fut) { + // There is a race condition with most tests that cancel plans. If the + // cancel call comes in too slowly then the plan might have already finished + // ok. + ASSERT_TRUE(fut.Wait(kDefaultAssertFinishesWaitSeconds)); + if (!fut.status().ok()) { + ASSERT_TRUE(fut.status().IsCancelled()); + } +} + +TEST(ExecPlanExecution, PipeErrorSink) { + auto basic_data = MakeBasicBatches(); + + AsyncGenerator> main_sink_gen; + AsyncGenerator> dup_sink_gen; + AsyncGenerator> dup1_sink_gen; + Declaration decl = Declaration::Sequence( + {{"source", SourceNodeOptions{basic_data.schema, basic_data.gen(/*parallel=*/false, + /*slow=*/false)}}, + {"pipe_tee", PipeSinkNodeOptions{"named_pipe_1"}}, + {"sink", SinkNodeOptions{&main_sink_gen}}}); + + Declaration dup = Declaration::Sequence( + {{"pipe_source", PipeSourceNodeOptions{"named_pipe_1", basic_data.schema}}, + {"sink", SinkNodeOptions{&dup_sink_gen}}}); + + Declaration dup1 = Declaration::Sequence( + {{"pipe_source", PipeSourceNodeOptions{"named_pipe_1", basic_data.schema}}, + {"filter", FilterNodeOptions{equal(field_ref("i32"), literal("6"))}}, + {"sink", SinkNodeOptions{&dup1_sink_gen}}}); + + auto bad_schema = schema({field("i32", uint32()), field("bool", boolean())}); + Declaration dup2 = Declaration::Sequence( + {{"pipe_source", PipeSourceNodeOptions{"named_pipe_1", bad_schema}}, + {"filter", FilterNodeOptions{equal(field_ref("i32"), literal(6))}}, + {"sink", SinkNodeOptions{&dup1_sink_gen}}}); + dup1.label = "dup1"; + + // fail on planning + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + ASSERT_OK(decl.AddToPlan(plan.get())); + ASSERT_OK(dup.AddToPlan(plan.get())); + ASSERT_THAT( + dup1.AddToPlan(plan.get()).status(), + Raises(StatusCode::NotImplemented, + HasSubstr( + "Function 'equal' has no kernel matching input types (int32, string)"))); + // fail on exec + ASSERT_OK_AND_ASSIGN(plan, ExecPlan::Make()); + ASSERT_OK(decl.AddToPlan(plan.get())); + ASSERT_OK(dup.AddToPlan(plan.get())); + ASSERT_OK(dup2.AddToPlan(plan.get())); + ASSERT_OK(plan->Validate()); + std::vector>> gens = { + main_sink_gen, dup_sink_gen, dup1_sink_gen}; + plan->StartProducing(); + ASSERT_THAT(plan->finished().result().status(), + Raises(StatusCode::Invalid, + HasSubstr("Pipe schema does not match for named_pipe_1"))); + + Declaration dup3 = Declaration::Sequence( + {{"pipe_source", + PipeSourceNodeOptions{"named_pipe_missing_sink", basic_data.schema}}, + {"sink", SinkNodeOptions{&dup_sink_gen}}}); + + ASSERT_OK_AND_ASSIGN(plan, 
ExecPlan::Make()); + ASSERT_OK(decl.AddToPlan(plan.get())); + ASSERT_OK(dup.AddToPlan(plan.get())); + ASSERT_OK(dup3.AddToPlan(plan.get())); + ASSERT_OK(plan->Validate()); + gens = {main_sink_gen, dup_sink_gen, dup1_sink_gen}; + plan->StartProducing(); + ASSERT_THAT( + plan->finished().result().status(), + Raises(StatusCode::Invalid, + HasSubstr("Pipe 'named_pipe_missing_sink' error: Pipe does not have sink"))); +} + +TEST(ExecPlanExecution, PipeErrorDuplicateSink) { + auto basic_data = MakeBasicBatches(); + + AsyncGenerator> main_sink_gen; + AsyncGenerator> dup_sink_gen; + Declaration decl = Declaration::Sequence( + {{"source", SourceNodeOptions{basic_data.schema, basic_data.gen(/*parallel=*/false, + /*slow=*/false)}}, + {"pipe_tee", PipeSinkNodeOptions{"named_pipe_1"}}, + {"pipe_tee", PipeSinkNodeOptions{"named_pipe_1"}}, + {"sink", SinkNodeOptions{&main_sink_gen}}}); + + Declaration dup = Declaration::Sequence( + {{"pipe_source", PipeSourceNodeOptions{"named_pipe_1", basic_data.schema}}, + {"sink", SinkNodeOptions{&dup_sink_gen}}}); + + // fail on planning + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + ASSERT_OK(decl.AddToPlan(plan.get())); + ASSERT_OK(dup.AddToPlan(plan.get())); + plan->StartProducing(); + ASSERT_THAT( + plan->finished().result().status(), + Raises(StatusCode::Invalid, HasSubstr("Pipe:named_pipe_1 has multiple sinks"))); +} + +TEST(ExecPlanExecution, PipeFilterSink) { + auto basic_data = MakeBasicBatches(); + AsyncGenerator> main_sink_gen; + AsyncGenerator> dup_sink_gen; + AsyncGenerator> dup1_sink_gen; + Declaration decl = Declaration::Sequence( + {{"source", SourceNodeOptions{basic_data.schema, basic_data.gen(/*parallel=*/false, + /*slow=*/false)}}, + {"pipe_tee", PipeSinkNodeOptions{"named_pipe_1"}}, + {"filter", FilterNodeOptions{equal(field_ref("i32"), literal(6))}}, + {"sink", SinkNodeOptions{&main_sink_gen}}}); + + Declaration dup = Declaration::Sequence( + {{"pipe_source", PipeSourceNodeOptions{"named_pipe_1", basic_data.schema}}, + {"sink", SinkNodeOptions{&dup_sink_gen}}}); + + Declaration dup1 = Declaration::Sequence( + {{"pipe_source", PipeSourceNodeOptions{"named_pipe_1", basic_data.schema}}, + {"filter", FilterNodeOptions{and_(greater(field_ref("i32"), literal(3)), + less(field_ref("i32"), literal(6)))}}, + {"sink", SinkNodeOptions{&dup1_sink_gen}}}); + dup1.label = "dup1"; + + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + ASSERT_OK(decl.AddToPlan(plan.get())); + ASSERT_OK(dup.AddToPlan(plan.get())); + ASSERT_OK(dup1.AddToPlan(plan.get())); + ASSERT_OK(plan->Validate()); + std::vector>> gens = { + main_sink_gen, dup_sink_gen, dup1_sink_gen}; + ASSERT_FINISHES_OK_AND_ASSIGN(auto exec_batches_vec, StartAndCollect(plan.get(), gens)); + + ASSERT_EQ(exec_batches_vec.size(), 3); + + auto exp_batches = {ExecBatchFromJSON({int32(), boolean()}, "[]"), + ExecBatchFromJSON({int32(), boolean()}, "[[6, false]]")}; + auto exp_dup1 = {ExecBatchFromJSON({int32(), boolean()}, "[[4, false]]"), + ExecBatchFromJSON({int32(), boolean()}, "[[5, null]]")}; + AssertExecBatchesEqualIgnoringOrder(basic_data.schema, exp_batches, + exec_batches_vec[0]); + AssertExecBatchesEqualIgnoringOrder(basic_data.schema, basic_data.batches, + exec_batches_vec[1]); + AssertExecBatchesEqualIgnoringOrder(basic_data.schema, exp_dup1, exec_batches_vec[2]); +} + +TEST(ExecPlanExecution, PipeMultiSchemaSink) { + auto basic_data = MakeBasicBatches(); + AsyncGenerator> main_sink_gen; + AsyncGenerator> dup_sink_gen; + AsyncGenerator> dup1_sink_gen; + + Expression strbool = + 
call("if_else", {field_ref("bool"), literal("true"), literal("false")}); + + Declaration decl = Declaration::Sequence( + {{"source", SourceNodeOptions{basic_data.schema, basic_data.gen(/*parallel=*/false, + /*slow=*/false)}}, + {"pipe_tee", PipeSinkNodeOptions{"named_pipe_1"}}, + {"project", ProjectNodeOptions({strbool}, {"str"})}, + {"pipe_tee", PipeSinkNodeOptions{"named_pipe_2"}}, + {"sink", SinkNodeOptions{&main_sink_gen}}}); + + std::shared_ptr schema_2 = schema({field("str", utf8())}); + Declaration dup = Declaration::Sequence( + {{"pipe_source", PipeSourceNodeOptions{"named_pipe_1", basic_data.schema}}, + {"sink", SinkNodeOptions{&dup_sink_gen}}}); + + Declaration dup1 = Declaration::Sequence( + {{"pipe_source", PipeSourceNodeOptions{"named_pipe_2", schema_2}}, + {"sink", SinkNodeOptions{&dup1_sink_gen}}}); + dup1.label = "dup1"; + + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + ASSERT_OK(decl.AddToPlan(plan.get())); + ASSERT_OK(dup.AddToPlan(plan.get())); + ASSERT_OK(dup1.AddToPlan(plan.get())); + ASSERT_OK(plan->Validate()); + + std::vector>> gens = { + main_sink_gen, dup_sink_gen, dup1_sink_gen}; + ASSERT_FINISHES_OK_AND_ASSIGN(auto exec_batches_vec, StartAndCollect(plan.get(), gens)); + + ASSERT_EQ(exec_batches_vec.size(), 3); + + // ASSERT_OK_AND_ASSIGN(auto result, DeclarationToExecBatches(std::move(plan))); + auto exp_batches = {ExecBatchFromJSON({utf8()}, "[[\"true\"],[\"false\"]]"), + ExecBatchFromJSON({utf8()}, "[[\"false\"],[\"false\"],[null]]")}; + AssertExecBatchesEqualIgnoringOrder(schema_2, exp_batches, exec_batches_vec[0]); + AssertExecBatchesEqualIgnoringOrder(basic_data.schema, basic_data.batches, + exec_batches_vec[1]); + AssertExecBatchesEqualIgnoringOrder(schema_2, exp_batches, exec_batches_vec[2]); +} + +TEST(ExecPlanExecution, PipeBackpressure) { + std::optional batch = + ExecBatchFromJSON({int32(), boolean()}, + "[[4, false], [5, null], [6, false], [7, false], [null, true]]"); + constexpr uint32_t kPauseIfAbove = 4; + constexpr uint32_t kResumeIfBelow = 2; + uint32_t pause_if_above_bytes = + kPauseIfAbove * static_cast(batch->TotalBufferSize()); + uint32_t resume_if_below_bytes = + kResumeIfBelow * static_cast(batch->TotalBufferSize()); + EXPECT_OK_AND_ASSIGN(std::shared_ptr plan, ExecPlan::Make()); + PushGenerator> batch_producer; + BackpressureOptions backpressure_options(resume_if_below_bytes, pause_if_above_bytes); + struct SinkDesc { + AsyncGenerator> sink_gen; + BackpressureMonitor* backpressure_monitor; + }; + SinkDesc mainSink; + std::shared_ptr schema_ = schema({field("data", int32())}); + ARROW_EXPECT_OK(acero::Declaration::Sequence( + { + {"source", SourceNodeOptions(schema_, batch_producer)}, + {"pipe_tee", PipeSinkNodeOptions{"named_pipe_1"}}, + {"sink", SinkNodeOptions{&mainSink.sink_gen, /*schema=*/nullptr, + backpressure_options, + &mainSink.backpressure_monitor}}, + }) + .AddToPlan(plan.get())); + + SinkDesc dupSink; + ARROW_EXPECT_OK(acero::Declaration::Sequence( + { + {"pipe_source", PipeSourceNodeOptions{"named_pipe_1", schema_}}, + {"sink", SinkNodeOptions{&dupSink.sink_gen, /*schema=*/nullptr, + backpressure_options, + &dupSink.backpressure_monitor}}, + }) + .AddToPlan(plan.get())); + + SinkDesc dup1Sink; + ARROW_EXPECT_OK( + acero::Declaration::Sequence( + { + {"pipe_source", PipeSourceNodeOptions{"named_pipe_1", schema_}}, + //{"filter", FilterNodeOptions{equal(field_ref("data"), + // literal("6"))}}, + {"filter", FilterNodeOptions{equal(field_ref("data"), literal(6))}}, + {"sink", + SinkNodeOptions{&dup1Sink.sink_gen, 
/*schema=*/nullptr, + backpressure_options, &dup1Sink.backpressure_monitor}}, + }) + .AddToPlan(plan.get())); + + ASSERT_TRUE(mainSink.backpressure_monitor); + ASSERT_TRUE(dupSink.backpressure_monitor); + ASSERT_TRUE(dup1Sink.backpressure_monitor); + plan->StartProducing(); + auto plan_finished = plan->finished(); + + ASSERT_FALSE(mainSink.backpressure_monitor->is_paused()); + ASSERT_FALSE(dupSink.backpressure_monitor->is_paused()); + ASSERT_FALSE(dup1Sink.backpressure_monitor->is_paused()); + + // Should be able to push kPauseIfAbove batches without triggering back pressure + for (uint32_t i = 0; i < kPauseIfAbove; i++) { + batch_producer.producer().Push(batch); + } + + SleepABit(); + ASSERT_FALSE(mainSink.backpressure_monitor->is_paused()); + ASSERT_FALSE(dupSink.backpressure_monitor->is_paused()); + ASSERT_FALSE(dup1Sink.backpressure_monitor->is_paused()); + + // One more batch should trigger back pressure + batch_producer.producer().Push(batch); + BusyWait(10, [&] { return mainSink.backpressure_monitor->is_paused(); }); + ASSERT_TRUE(mainSink.backpressure_monitor->is_paused()); + BusyWait(10, [&] { return dupSink.backpressure_monitor->is_paused(); }); + ASSERT_TRUE(dupSink.backpressure_monitor->is_paused()); + ASSERT_FALSE(dup1Sink.backpressure_monitor->is_paused()); + + // Reading as much as we can while keeping it paused + for (uint32_t i = kPauseIfAbove; i >= kResumeIfBelow; i--) { + ASSERT_FINISHES_OK(mainSink.sink_gen()); + ASSERT_FINISHES_OK(dupSink.sink_gen()); + } + SleepABit(); + ASSERT_TRUE(mainSink.backpressure_monitor->is_paused()); + ASSERT_TRUE(dupSink.backpressure_monitor->is_paused()); + + // Reading one more item should open up backpressure + ASSERT_FINISHES_OK(mainSink.sink_gen()); + BusyWait(10, [&] { return !mainSink.backpressure_monitor->is_paused(); }); + ASSERT_FALSE(mainSink.backpressure_monitor->is_paused()); + + ASSERT_FINISHES_OK(dupSink.sink_gen()); + BusyWait(10, [&] { return !dupSink.backpressure_monitor->is_paused(); }); + ASSERT_FALSE(dupSink.backpressure_monitor->is_paused()); + + for (uint32_t i = 0; i < kResumeIfBelow - 1; i++) { + ASSERT_FINISHES_OK(mainSink.sink_gen()); + ASSERT_FINISHES_OK(dupSink.sink_gen()); + } + SleepABit(); + ASSERT_EQ(mainSink.backpressure_monitor->bytes_in_use(), 0); + ASSERT_EQ(dupSink.backpressure_monitor->bytes_in_use(), 0); + + for (uint32_t i = 0; i < kPauseIfAbove * 2; i++) { + batch_producer.producer().Push(batch); + ASSERT_FINISHES_OK(mainSink.sink_gen()); + ASSERT_FINISHES_OK(dupSink.sink_gen()); + } + SleepABit(); + ASSERT_FALSE(mainSink.backpressure_monitor->is_paused()); + ASSERT_FALSE(dupSink.backpressure_monitor->is_paused()); + ASSERT_FALSE(dup1Sink.backpressure_monitor->is_paused()); + + batch_producer.producer().Push(batch); + ASSERT_FINISHES_OK(mainSink.sink_gen()); + ASSERT_FINISHES_OK(dupSink.sink_gen()); + SleepABit(); + + ASSERT_FALSE(mainSink.backpressure_monitor->is_paused()); + ASSERT_FALSE(dupSink.backpressure_monitor->is_paused()); + ASSERT_TRUE(dup1Sink.backpressure_monitor->is_paused()); + + // Cleanup + batch_producer.producer().Push(IterationEnd>()); + plan->StopProducing(); + CheckFinishesCancelledOrOk(plan->finished()); +} + +} // namespace acero +} // namespace arrow diff --git a/cpp/src/arrow/acero/sorted_merge_node.cc b/cpp/src/arrow/acero/sorted_merge_node.cc index 43f5b7b930a..1b907fdd863 100644 --- a/cpp/src/arrow/acero/sorted_merge_node.cc +++ b/cpp/src/arrow/acero/sorted_merge_node.cc @@ -24,6 +24,7 @@ #include #include +#include "arrow/acero/backpressure.h" #include 
"arrow/acero/concurrent_queue_internal.h" #include "arrow/acero/exec_plan.h" #include "arrow/acero/exec_plan_internal.h" @@ -82,21 +83,6 @@ using col_index_t = int; constexpr bool kNewTask = true; constexpr bool kPoisonPill = false; -class BackpressureController : public BackpressureControl { - public: - BackpressureController(ExecNode* node, ExecNode* output, - std::atomic& backpressure_counter) - : node_(node), output_(output), backpressure_counter_(backpressure_counter) {} - - void Pause() override { node_->PauseProducing(output_, ++backpressure_counter_); } - void Resume() override { node_->ResumeProducing(output_, ++backpressure_counter_); } - - private: - ExecNode* node_; - ExecNode* output_; - std::atomic& backpressure_counter_; -}; - /// InputState corresponds to an input. Input record batches are queued up in InputState /// until processed and turned into output record batches. class InputState { diff --git a/cpp/src/arrow/acero/test_util_internal.cc b/cpp/src/arrow/acero/test_util_internal.cc index 3e279175810..abf3883c2cd 100644 --- a/cpp/src/arrow/acero/test_util_internal.cc +++ b/cpp/src/arrow/acero/test_util_internal.cc @@ -167,6 +167,44 @@ Future> StartAndCollect( }); } +Future>> StartAndCollect( + ExecPlan* plan, std::vector>> gens) { + RETURN_NOT_OK(plan->Validate()); + plan->StartProducing(); + + auto collected_futs = ::arrow::internal::MapVector( + [](AsyncGenerator> gen) { + return CollectAsyncGenerator(gen); + }, + std::move(gens)); + + auto wait_for = ::arrow::internal::MapVector( + [](Future>> fut) { return Future<>(fut); }, + collected_futs); + wait_for.push_back(plan->finished()); + + static_assert( + std::is_same_v>>>>); + + auto r = AllFinished(wait_for).Then([collected_futs]() mutable + -> Result>> { + return ::arrow::internal::MaybeMapVector( + [](Future>> collected_fut) + -> Result> { + ARROW_ASSIGN_OR_RAISE(auto collected, collected_fut.result()); + + static_assert( + std::is_same_v>>); + return ::arrow::internal::MapVector( + [](std::optional batch) { return batch.value_or(ExecBatch()); }, + std::move(collected)); + }, + std::move(collected_futs)); + }); + return r; +} + namespace { Result MakeIntegerBatch(const std::vector>& gens, diff --git a/cpp/src/arrow/acero/test_util_internal.h b/cpp/src/arrow/acero/test_util_internal.h index e8ff9103cd6..df15f19a04e 100644 --- a/cpp/src/arrow/acero/test_util_internal.h +++ b/cpp/src/arrow/acero/test_util_internal.h @@ -89,6 +89,9 @@ Future<> StartAndFinish(ExecPlan* plan); Future> StartAndCollect( ExecPlan* plan, AsyncGenerator> gen); +Future>> StartAndCollect( + ExecPlan* plan, std::vector>> gens); + AsyncGenerator> MakeIntegerBatchGen( const std::vector>& gens, const std::shared_ptr& schema, int num_batches, int batch_size); diff --git a/cpp/src/arrow/util/vector.h b/cpp/src/arrow/util/vector.h index e77d713a44d..42005cea235 100644 --- a/cpp/src/arrow/util/vector.h +++ b/cpp/src/arrow/util/vector.h @@ -124,7 +124,7 @@ Result> MaybeMapVector(Fn&& map, std::vector&& source) { ARROW_RETURN_NOT_OK(MaybeTransform(std::make_move_iterator(source.begin()), std::make_move_iterator(source.end()), std::back_inserter(out), std::forward(map))); - return std::move(out); + return out; } template