Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion be/src/vec/runtime/vdata_stream_mgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,23 @@ Status VDataStreamMgr::transmit_block(const PTransmitDataParams* request,
for (int i = 0; i < request->blocks_size(); i++) {
std::unique_ptr<PBlock> pblock_ptr = std::make_unique<PBlock>();
pblock_ptr->Swap(const_cast<PBlock*>(&request->blocks(i)));
auto pass_done = [&]() -> ::google::protobuf::Closure** {
// If it is eos, no callback is needed, done can be nullptr
if (eos) {
return nullptr;
}
// If it is the last block, a callback is needed, pass done
if (i == request->blocks_size() - 1) {
return done;
} else {
// If it is not the last block, the blocks in the request currently belong to the same queue,
// and the callback is handled by the done of the last block
return nullptr;
}
};
RETURN_IF_ERROR(recvr->add_block(
std::move(pblock_ptr), request->sender_id(), request->be_number(),
request->packet_seq() - request->blocks_size() + i, eos ? nullptr : done,
request->packet_seq() - request->blocks_size() + i, pass_done(),
wait_for_worker, cpu_time_stop_watch.elapsed_time()));
}
}
Expand Down
8 changes: 5 additions & 3 deletions be/src/vec/runtime/vdata_stream_mgr.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <unordered_map>
#include <utility>

#include "common/be_mock_util.h"
#include "common/global_types.h"
#include "common/status.h"
#include "util/runtime_profile.h"
Expand All @@ -52,15 +53,16 @@ class VDataStreamRecvr;
class VDataStreamMgr {
public:
VDataStreamMgr();
~VDataStreamMgr();
MOCK_FUNCTION ~VDataStreamMgr();

std::shared_ptr<VDataStreamRecvr> create_recvr(
RuntimeState* state, RuntimeProfile::HighWaterMarkCounter* memory_used_counter,
const TUniqueId& fragment_instance_id, PlanNodeId dest_node_id, int num_senders,
RuntimeProfile* profile, bool is_merging, size_t data_queue_capacity);

Status find_recvr(const TUniqueId& fragment_instance_id, PlanNodeId node_id,
std::shared_ptr<VDataStreamRecvr>* res, bool acquire_lock = true);
MOCK_FUNCTION Status find_recvr(const TUniqueId& fragment_instance_id, PlanNodeId node_id,
std::shared_ptr<VDataStreamRecvr>* res,
bool acquire_lock = true);

Status deregister_recvr(const TUniqueId& fragment_instance_id, PlanNodeId node_id);

Expand Down
4 changes: 3 additions & 1 deletion be/src/vec/runtime/vdata_stream_recvr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,9 @@ void VDataStreamRecvr::close() {
}
// Remove this receiver from the DataStreamMgr that created it.
// TODO: log error msg
static_cast<void>(_mgr->deregister_recvr(fragment_instance_id(), dest_node_id()));
if (_mgr) {
static_cast<void>(_mgr->deregister_recvr(fragment_instance_id(), dest_node_id()));
}
_mgr = nullptr;

_merger.reset();
Expand Down
2 changes: 1 addition & 1 deletion be/src/vec/runtime/vdata_stream_recvr.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class VDataStreamRecvr : public HasTaskExecutionCtx {
// Careful: stream sender will call this function for a local receiver,
// accessing members of receiver that are allocated by Object pool
// in this function is not safe.
bool exceeds_limit(size_t block_byte_size);
MOCK_FUNCTION bool exceeds_limit(size_t block_byte_size);
bool queue_exceeds_limit(size_t byte_size) const;
bool is_closed() const { return _is_closed; }

Expand Down
56 changes: 54 additions & 2 deletions be/test/pipeline/exec/vdata_stream_recvr_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "testutil/column_helper.h"
#include "testutil/mock/mock_runtime_state.h"
#include "vec/data_types/data_type_number.h"
#include "vec/runtime/vdata_stream_mgr.h"

namespace doris::pipeline {
using namespace vectorized;
Expand All @@ -37,6 +38,14 @@ struct MockVDataStreamRecvr : public VDataStreamRecvr {
RuntimeProfile* profile, int num_senders, bool is_merging)
: VDataStreamRecvr(nullptr, counter, state, TUniqueId(), 0, num_senders, is_merging,
profile, 1) {};

bool exceeds_limit(size_t block_byte_size) override {
if (always_exceeds_limit) {
return true;
}
return VDataStreamRecvr::exceeds_limit(block_byte_size);
}
bool always_exceeds_limit = false;
};

class DataStreamRecvrTest : public testing::Test {
Expand All @@ -50,12 +59,12 @@ class DataStreamRecvrTest : public testing::Test {
std::make_unique<RuntimeProfile::HighWaterMarkCounter>(TUnit::UNIT, 0, "test");
_mock_state = std::make_unique<MockRuntimeState>();
_mock_profile = std::make_unique<RuntimeProfile>("test");
recvr = std::make_unique<MockVDataStreamRecvr>(_mock_state.get(), _mock_counter.get(),
recvr = std::make_shared<MockVDataStreamRecvr>(_mock_state.get(), _mock_counter.get(),
_mock_profile.get(), num_senders,
is_merging);
}

std::unique_ptr<MockVDataStreamRecvr> recvr;
std::shared_ptr<MockVDataStreamRecvr> recvr;

std::unique_ptr<RuntimeProfile::HighWaterMarkCounter> _mock_counter;

Expand Down Expand Up @@ -564,6 +573,49 @@ TEST_F(DataStreamRecvrTest, TestRemoteLocalMultiSender) {
input3.join();
output.join();
}

struct MockVDataStreamMgr : public VDataStreamMgr {
~MockVDataStreamMgr() override = default;
Status find_recvr(const TUniqueId& fragment_instance_id, PlanNodeId node_id,
std::shared_ptr<VDataStreamRecvr>* res, bool acquire_lock = true) override {
*res = recvr;
return Status::OK();
}

std::shared_ptr<VDataStreamRecvr> recvr;
};

TEST_F(DataStreamRecvrTest, transmit_block) {
create_recvr(1, true);
recvr->always_exceeds_limit = true;

MockVDataStreamMgr mgr;
mgr.recvr = recvr;

MockClosure closure;
closure._cb = [&]() { std::cout << "cb" << std::endl; };
google::protobuf::Closure* done = &closure;

PTransmitDataParams request;
{
auto* pblock = request.add_blocks();
auto block = ColumnHelper::create_block<DataTypeInt32>({1, 2, 3, 4, 5});
to_pblock(block, pblock);
}

{
auto* pblock = request.add_blocks();
auto block = ColumnHelper::create_block<DataTypeInt32>({1, 2, 3, 4, 5});
to_pblock(block, pblock);
}

{
auto st = mgr.transmit_block(&request, &done, 1000);
EXPECT_TRUE(st) << st.msg();
}
recvr->close();
}

// ./run-be-ut.sh --run --filter=DataStreamRecvrTest.*

} // namespace doris::pipeline
Loading