diff --git a/be/src/vec/runtime/vdata_stream_mgr.cpp b/be/src/vec/runtime/vdata_stream_mgr.cpp index 2a4f4e22861beb..c81fa21fa3406e 100644 --- a/be/src/vec/runtime/vdata_stream_mgr.cpp +++ b/be/src/vec/runtime/vdata_stream_mgr.cpp @@ -150,9 +150,23 @@ Status VDataStreamMgr::transmit_block(const PTransmitDataParams* request, for (int i = 0; i < request->blocks_size(); i++) { std::unique_ptr pblock_ptr = std::make_unique(); pblock_ptr->Swap(const_cast(&request->blocks(i))); + auto pass_done = [&]() -> ::google::protobuf::Closure** { + // If it is eos, no callback is needed, done can be nullptr + if (eos) { + return nullptr; + } + // If it is the last block, a callback is needed, pass done + if (i == request->blocks_size() - 1) { + return done; + } else { + // If it is not the last block, the blocks in the request currently belong to the same queue, + // and the callback is handled by the done of the last block + return nullptr; + } + }; RETURN_IF_ERROR(recvr->add_block( std::move(pblock_ptr), request->sender_id(), request->be_number(), - request->packet_seq() - request->blocks_size() + i, eos ? nullptr : done, + request->packet_seq() - request->blocks_size() + i, pass_done(), wait_for_worker, cpu_time_stop_watch.elapsed_time())); } } diff --git a/be/src/vec/runtime/vdata_stream_mgr.h b/be/src/vec/runtime/vdata_stream_mgr.h index f9b5bbe5bcd802..a9266d02d969ac 100644 --- a/be/src/vec/runtime/vdata_stream_mgr.h +++ b/be/src/vec/runtime/vdata_stream_mgr.h @@ -27,6 +27,7 @@ #include #include +#include "common/be_mock_util.h" #include "common/global_types.h" #include "common/status.h" #include "util/runtime_profile.h" @@ -52,15 +53,16 @@ class VDataStreamRecvr; class VDataStreamMgr { public: VDataStreamMgr(); - ~VDataStreamMgr(); + MOCK_FUNCTION ~VDataStreamMgr(); std::shared_ptr create_recvr( RuntimeState* state, RuntimeProfile::HighWaterMarkCounter* memory_used_counter, const TUniqueId& fragment_instance_id, PlanNodeId dest_node_id, int num_senders, RuntimeProfile* profile, bool is_merging, size_t data_queue_capacity); - Status find_recvr(const TUniqueId& fragment_instance_id, PlanNodeId node_id, - std::shared_ptr* res, bool acquire_lock = true); + MOCK_FUNCTION Status find_recvr(const TUniqueId& fragment_instance_id, PlanNodeId node_id, + std::shared_ptr* res, + bool acquire_lock = true); Status deregister_recvr(const TUniqueId& fragment_instance_id, PlanNodeId node_id); diff --git a/be/src/vec/runtime/vdata_stream_recvr.cpp b/be/src/vec/runtime/vdata_stream_recvr.cpp index 603270b7206282..5bcd7dd1ef9287 100644 --- a/be/src/vec/runtime/vdata_stream_recvr.cpp +++ b/be/src/vec/runtime/vdata_stream_recvr.cpp @@ -489,7 +489,9 @@ void VDataStreamRecvr::close() { } // Remove this receiver from the DataStreamMgr that created it. // TODO: log error msg - static_cast(_mgr->deregister_recvr(fragment_instance_id(), dest_node_id())); + if (_mgr) { + static_cast(_mgr->deregister_recvr(fragment_instance_id(), dest_node_id())); + } _mgr = nullptr; _merger.reset(); diff --git a/be/src/vec/runtime/vdata_stream_recvr.h b/be/src/vec/runtime/vdata_stream_recvr.h index c311668ae824e5..9325f4ada3f724 100644 --- a/be/src/vec/runtime/vdata_stream_recvr.h +++ b/be/src/vec/runtime/vdata_stream_recvr.h @@ -111,7 +111,7 @@ class VDataStreamRecvr : public HasTaskExecutionCtx { // Careful: stream sender will call this function for a local receiver, // accessing members of receiver that are allocated by Object pool // in this function is not safe. - bool exceeds_limit(size_t block_byte_size); + MOCK_FUNCTION bool exceeds_limit(size_t block_byte_size); bool queue_exceeds_limit(size_t byte_size) const; bool is_closed() const { return _is_closed; } diff --git a/be/test/pipeline/exec/vdata_stream_recvr_test.cpp b/be/test/pipeline/exec/vdata_stream_recvr_test.cpp index 636772361c93ad..bed8be3e4f3495 100644 --- a/be/test/pipeline/exec/vdata_stream_recvr_test.cpp +++ b/be/test/pipeline/exec/vdata_stream_recvr_test.cpp @@ -28,6 +28,7 @@ #include "testutil/column_helper.h" #include "testutil/mock/mock_runtime_state.h" #include "vec/data_types/data_type_number.h" +#include "vec/runtime/vdata_stream_mgr.h" namespace doris::pipeline { using namespace vectorized; @@ -37,6 +38,14 @@ struct MockVDataStreamRecvr : public VDataStreamRecvr { RuntimeProfile* profile, int num_senders, bool is_merging) : VDataStreamRecvr(nullptr, counter, state, TUniqueId(), 0, num_senders, is_merging, profile, 1) {}; + + bool exceeds_limit(size_t block_byte_size) override { + if (always_exceeds_limit) { + return true; + } + return VDataStreamRecvr::exceeds_limit(block_byte_size); + } + bool always_exceeds_limit = false; }; class DataStreamRecvrTest : public testing::Test { @@ -50,12 +59,12 @@ class DataStreamRecvrTest : public testing::Test { std::make_unique(TUnit::UNIT, 0, "test"); _mock_state = std::make_unique(); _mock_profile = std::make_unique("test"); - recvr = std::make_unique(_mock_state.get(), _mock_counter.get(), + recvr = std::make_shared(_mock_state.get(), _mock_counter.get(), _mock_profile.get(), num_senders, is_merging); } - std::unique_ptr recvr; + std::shared_ptr recvr; std::unique_ptr _mock_counter; @@ -564,6 +573,49 @@ TEST_F(DataStreamRecvrTest, TestRemoteLocalMultiSender) { input3.join(); output.join(); } + +struct MockVDataStreamMgr : public VDataStreamMgr { + ~MockVDataStreamMgr() override = default; + Status find_recvr(const TUniqueId& fragment_instance_id, PlanNodeId node_id, + std::shared_ptr* res, bool acquire_lock = true) override { + *res = recvr; + return Status::OK(); + } + + std::shared_ptr recvr; +}; + +TEST_F(DataStreamRecvrTest, transmit_block) { + create_recvr(1, true); + recvr->always_exceeds_limit = true; + + MockVDataStreamMgr mgr; + mgr.recvr = recvr; + + MockClosure closure; + closure._cb = [&]() { std::cout << "cb" << std::endl; }; + google::protobuf::Closure* done = &closure; + + PTransmitDataParams request; + { + auto* pblock = request.add_blocks(); + auto block = ColumnHelper::create_block({1, 2, 3, 4, 5}); + to_pblock(block, pblock); + } + + { + auto* pblock = request.add_blocks(); + auto block = ColumnHelper::create_block({1, 2, 3, 4, 5}); + to_pblock(block, pblock); + } + + { + auto st = mgr.transmit_block(&request, &done, 1000); + EXPECT_TRUE(st) << st.msg(); + } + recvr->close(); +} + // ./run-be-ut.sh --run --filter=DataStreamRecvrTest.* } // namespace doris::pipeline