diff --git a/dbms/src/Common/BackgroundTask.cpp b/dbms/src/Common/BackgroundTask.cpp new file mode 100644 index 00000000000..16a23535541 --- /dev/null +++ b/dbms/src/Common/BackgroundTask.cpp @@ -0,0 +1,94 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +namespace DB +{ +bool process_mem_usage(double & resident_set) +{ + resident_set = 0.0; + + // 'file' stat seems to give the most reliable results + std::ifstream stat_stream("/proc/self/stat", std::ios_base::in); + // if "/proc/self/stat" is not supported + if (!stat_stream.is_open()) + return false; + + // dummy vars for leading entries in stat that we don't care about + std::string pid, comm, state, ppid, pgrp, session, tty_nr; + std::string tpgid, flags, minflt, cminflt, majflt, cmajflt; + std::string utime, stime, cutime, cstime, priority, nice; + std::string proc_num_threads, itrealvalue, starttime; + UInt64 vsize; + + // the field we want + Int64 rss; + + stat_stream >> pid >> comm >> state >> ppid >> pgrp >> session >> tty_nr + >> tpgid >> flags >> minflt >> cminflt >> majflt >> cmajflt + >> utime >> stime >> cutime >> cstime >> priority >> nice + >> proc_num_threads >> itrealvalue >> starttime >> vsize >> rss; // don't care about the rest + + stat_stream.close(); + + Int64 page_size_kb = sysconf(_SC_PAGE_SIZE) / 1024; // in case x86-64 is configured to use 2MB pages + resident_set = rss * page_size_kb; + return true; +} + +bool isProcStatSupported() +{ + std::ifstream stat_stream("/proc/self/stat", std::ios_base::in); + return stat_stream.is_open(); +} + +void CollectProcInfoBackgroundTask::begin() +{ + std::unique_lock lk(mu); + if (!is_already_begin) + { + if (!isProcStatSupported()) + { + end_fin = true; + return; + } + std::thread t = ThreadFactory::newThread(false, "MemTrackThread", &CollectProcInfoBackgroundTask::memCheckJob, this); + t.detach(); + is_already_begin = true; + } +} + +void CollectProcInfoBackgroundTask::memCheckJob() +{ + double resident_set; + while (!end_syn) + { + process_mem_usage(resident_set); + resident_set *= 1024; // unit: byte + real_rss = static_cast(resident_set); + + usleep(100000); // sleep 100ms + } + end_fin = true; +} + +void CollectProcInfoBackgroundTask::end() +{ + end_syn = true; + while (!end_fin) + usleep(1000); // Just ok since it is called only when TiFlash shutdown. +} +} // namespace DB diff --git a/dbms/src/Common/BackgroundTask.h b/dbms/src/Common/BackgroundTask.h new file mode 100644 index 00000000000..afeacccd3bc --- /dev/null +++ b/dbms/src/Common/BackgroundTask.h @@ -0,0 +1,39 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +namespace DB +{ +class CollectProcInfoBackgroundTask +{ +public: + CollectProcInfoBackgroundTask() = default; + ~CollectProcInfoBackgroundTask() + { + end(); + } + void begin(); + + void end(); + +private: + void memCheckJob(); + + std::mutex mu; + bool is_already_begin = false; + std::atomic end_syn{false}, end_fin{false}; +}; +} // namespace DB diff --git a/dbms/src/Common/MPMCQueue.h b/dbms/src/Common/MPMCQueue.h index ce6d4365393..db82cad0ff2 100644 --- a/dbms/src/Common/MPMCQueue.h +++ b/dbms/src/Common/MPMCQueue.h @@ -81,9 +81,13 @@ class MPMCQueue ~MPMCQueue() { - std::unique_lock lock(mu); - for (; read_pos < write_pos; ++read_pos) - destruct(getObj(read_pos)); + drain(); + } + + void finishAndDrain() + { + finish(); + drain(); } // Cannot to use copy/move constructor, @@ -418,6 +422,13 @@ class MPMCQueue obj.~T(); } + void drain() + { + std::unique_lock lock(mu); + for (; read_pos < write_pos; ++read_pos) + destruct(getObj(read_pos)); + } + template ALWAYS_INLINE bool changeStatus(F && action) { diff --git a/dbms/src/Common/MemoryTracker.cpp b/dbms/src/Common/MemoryTracker.cpp index f64881ae35a..da8ecf28e26 100644 --- a/dbms/src/Common/MemoryTracker.cpp +++ b/dbms/src/Common/MemoryTracker.cpp @@ -22,6 +22,7 @@ #include +std::atomic real_rss{0}; MemoryTracker::~MemoryTracker() { if (peak) @@ -46,6 +47,7 @@ MemoryTracker::~MemoryTracker() * then memory usage of 'next' memory trackers will be underestimated, * because amount will be decreased twice (first - here, second - when real 'free' happens). */ + // TODO In future, maybe we can find a better way to handle the "amount > 0" case. if (auto value = amount.load(std::memory_order_relaxed)) free(value); } @@ -80,7 +82,7 @@ void MemoryTracker::alloc(Int64 size, bool check_memory_limit) /// In this case, it doesn't matter. if (unlikely(fault_probability && drand48() < fault_probability)) { - free(size); + amount.fetch_sub(size, std::memory_order_relaxed); DB::FmtBuffer fmt_buf; fmt_buf.append("Memory tracker"); @@ -93,20 +95,33 @@ void MemoryTracker::alloc(Int64 size, bool check_memory_limit) throw DB::TiFlashException(fmt_buf.toString(), DB::Errors::Coprocessor::MemoryLimitExceeded); } - - if (unlikely(current_limit && will_be > current_limit)) + bool is_rss_too_large = (!next.load(std::memory_order_relaxed) && current_limit + && real_rss > current_limit + bytes_rss_larger_than_limit + && will_be > current_limit - (real_rss - current_limit - bytes_rss_larger_than_limit)); + if (is_rss_too_large + || unlikely(current_limit && will_be > current_limit)) { - free(size); + amount.fetch_sub(size, std::memory_order_relaxed); DB::FmtBuffer fmt_buf; fmt_buf.append("Memory limit"); if (description) fmt_buf.fmtAppend(" {}", description); - fmt_buf.fmtAppend(" exceeded: would use {} (attempt to allocate chunk of {} bytes), maximum: {}", - formatReadableSizeWithBinarySuffix(will_be), - size, - formatReadableSizeWithBinarySuffix(current_limit)); + if (!is_rss_too_large) + { // out of memory quota + fmt_buf.fmtAppend(" exceeded caused by 'out of memory quota for data computing' : would use {} for data computing (attempt to allocate chunk of {} bytes), limit of memory for data computing: {}", + formatReadableSizeWithBinarySuffix(will_be), + size, + formatReadableSizeWithBinarySuffix(current_limit)); + } + else + { // RSS too large + fmt_buf.fmtAppend(" exceeded caused by 'RSS(Resident Set Size) much larger than limit' : process memory size would be {} for (attempt to allocate chunk of {} bytes), limit of memory for data computing : {}", + formatReadableSizeWithBinarySuffix(real_rss), + size, + formatReadableSizeWithBinarySuffix(current_limit)); + } throw DB::TiFlashException(fmt_buf.toString(), DB::Errors::Coprocessor::MemoryLimitExceeded); } @@ -116,7 +131,17 @@ void MemoryTracker::alloc(Int64 size, bool check_memory_limit) peak.store(will_be, std::memory_order_relaxed); if (auto * loaded_next = next.load(std::memory_order_relaxed)) - loaded_next->alloc(size, check_memory_limit); + { + try + { + loaded_next->alloc(size, check_memory_limit); + } + catch (...) + { + amount.fetch_sub(size, std::memory_order_relaxed); + std::rethrow_exception(std::current_exception()); + } + } } @@ -130,7 +155,7 @@ void MemoryTracker::free(Int64 size) * Memory usage will be calculated with some error. * NOTE The code is not atomic. Not worth to fix. */ - if (new_amount < 0) + if (new_amount < 0 && !next.load(std::memory_order_relaxed)) // handle it only for root memory_tracker { amount.fetch_sub(new_amount); size += new_amount; @@ -170,7 +195,7 @@ thread_local MemoryTracker * current_memory_tracker = nullptr; namespace CurrentMemoryTracker { -static Int64 MEMORY_TRACER_SUBMIT_THRESHOLD = 8 * 1024 * 1024; // 8 MiB +static Int64 MEMORY_TRACER_SUBMIT_THRESHOLD = 1024 * 1024; // 1 MiB #if __APPLE__ && __clang__ static __thread Int64 local_delta{}; #else diff --git a/dbms/src/Common/MemoryTracker.h b/dbms/src/Common/MemoryTracker.h index c87ec713dda..3f4122edf9f 100644 --- a/dbms/src/Common/MemoryTracker.h +++ b/dbms/src/Common/MemoryTracker.h @@ -19,7 +19,7 @@ #include - +extern std::atomic real_rss; namespace CurrentMetrics { extern const Metric MemoryTracking; @@ -35,6 +35,9 @@ class MemoryTracker std::atomic peak{0}; std::atomic limit{0}; + // How many bytes RSS(Resident Set Size) can be larger than limit(max_memory_usage_for_all_queries). Default: 5GB + Int64 bytes_rss_larger_than_limit = 5368709120; + /// To test exception safety of calling code, memory tracker throws an exception on each memory allocation with specified probability. double fault_probability = 0; @@ -70,6 +73,8 @@ class MemoryTracker Int64 getPeak() const { return peak.load(std::memory_order_relaxed); } + Int64 getLimit() const { return limit.load(std::memory_order_relaxed); } + void setLimit(Int64 limit_) { limit.store(limit_, std::memory_order_relaxed); } /** Set limit if it was not set. @@ -77,6 +82,8 @@ class MemoryTracker */ void setOrRaiseLimit(Int64 value); + void setBytesThatRssLargerThanLimit(Int64 value) { bytes_rss_larger_than_limit = value; } + void setFaultProbability(double value) { fault_probability = value; } /// next should be changed only once: from nullptr to some value. diff --git a/dbms/src/Common/tests/gtest_memtracker.cpp b/dbms/src/Common/tests/gtest_memtracker.cpp new file mode 100644 index 00000000000..d31e7b42df4 --- /dev/null +++ b/dbms/src/Common/tests/gtest_memtracker.cpp @@ -0,0 +1,121 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +namespace DB::tests +{ +namespace +{ +class MemTrackerTest : public ::testing::Test +{ +}; + +TEST_F(MemTrackerTest, testBasic) +try +{ + MemoryTracker mem_tracker; + mem_tracker.alloc(1024); + ASSERT_EQ(1024, mem_tracker.get()); + mem_tracker.free(1024); + ASSERT_EQ(0, mem_tracker.get()); +} +CATCH + +TEST_F(MemTrackerTest, testRootAndChild) +try +{ + MemoryTracker root_mem_tracker; + MemoryTracker child_mem_tracker(512); + child_mem_tracker.setNext(&root_mem_tracker); + // alloc 500 + child_mem_tracker.alloc(500); + ASSERT_EQ(500, child_mem_tracker.get()); + ASSERT_EQ(500, root_mem_tracker.get()); + + // alloc 256 base on 500 + bool has_err = false; + try + { + child_mem_tracker.alloc(256); //500 + 256 > limit(512) + } + catch (...) + { + has_err = true; + } + ASSERT_TRUE(has_err); + ASSERT_EQ(500, child_mem_tracker.get()); + ASSERT_EQ(500, root_mem_tracker.get()); + + //free 500 + child_mem_tracker.free(500); + ASSERT_EQ(0, child_mem_tracker.get()); + ASSERT_EQ(0, root_mem_tracker.get()); +} +CATCH + +TEST_F(MemTrackerTest, testRootAndMultipleChild) +try +{ + MemoryTracker root(512); // limit 512 + MemoryTracker child1(512); // limit 512 + MemoryTracker child2(512); // limit 512 + child1.setNext(&root); + child2.setNext(&root); + // alloc 500 on child1 + child1.alloc(500); + ASSERT_EQ(500, child1.get()); + ASSERT_EQ(0, child2.get()); + ASSERT_EQ(500, root.get()); + + + // alloc 500 on child2, should fail + bool has_err = false; + try + { + child2.alloc(500); // root will throw error because of "out of quota" + } + catch (...) + { + has_err = true; + } + ASSERT_TRUE(has_err); + ASSERT_EQ(500, child1.get()); + ASSERT_EQ(0, child2.get()); + ASSERT_EQ(500, root.get()); + + // alloc 10 on child2 + child2.alloc(10); + ASSERT_EQ(500, child1.get()); + ASSERT_EQ(10, child2.get()); + ASSERT_EQ(510, root.get()); + + // free 500 on child1 + child1.free(500); + ASSERT_EQ(0, child1.get()); + ASSERT_EQ(10, child2.get()); + ASSERT_EQ(10, root.get()); + + // free 10 on child2 + child2.free(10); + ASSERT_EQ(0, child1.get()); + ASSERT_EQ(0, child2.get()); + ASSERT_EQ(0, root.get()); +} +CATCH + + +} // namespace +} // namespace DB::tests diff --git a/dbms/src/Flash/Coprocessor/DAGContext.cpp b/dbms/src/Flash/Coprocessor/DAGContext.cpp index 6167090194a..46cb5e3bbf5 100644 --- a/dbms/src/Flash/Coprocessor/DAGContext.cpp +++ b/dbms/src/Flash/Coprocessor/DAGContext.cpp @@ -240,4 +240,8 @@ const SingleTableRegions & DAGContext::getTableRegionsInfoByTableID(Int64 table_ { return tables_regions_info.getTableRegionInfoByTableID(table_id); } +const MPPReceiverSetPtr & DAGContext::getMppReceiverSet() const +{ + return mpp_receiver_set; +} } // namespace DB diff --git a/dbms/src/Flash/Coprocessor/DAGContext.h b/dbms/src/Flash/Coprocessor/DAGContext.h index 93e7edda7e8..af5932ac444 100644 --- a/dbms/src/Flash/Coprocessor/DAGContext.h +++ b/dbms/src/Flash/Coprocessor/DAGContext.h @@ -305,6 +305,7 @@ class DAGContext { mpp_receiver_set = receiver_set; } + const MPPReceiverSetPtr & getMppReceiverSet() const; void addCoprocessorReader(const CoprocessorReaderPtr & coprocessor_reader); std::vector & getCoprocessorReaders(); diff --git a/dbms/src/Flash/Coprocessor/StreamWriter.h b/dbms/src/Flash/Coprocessor/StreamWriter.h index fad403b0726..41b108492f5 100644 --- a/dbms/src/Flash/Coprocessor/StreamWriter.h +++ b/dbms/src/Flash/Coprocessor/StreamWriter.h @@ -54,7 +54,7 @@ struct StreamWriter { ::coprocessor::BatchResponse resp; if (!response.SerializeToString(resp.mutable_data())) - throw Exception("Fail to serialize response, response size: " + std::to_string(response.ByteSizeLong())); + throw Exception("[StreamWriter]Fail to serialize response, response size: " + std::to_string(response.ByteSizeLong())); std::lock_guard lk(write_mutex); if (!writer->Write(resp)) throw Exception("Failed to write resp"); diff --git a/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.cpp b/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.cpp index 6e70f280e6f..5fa7eb245ab 100644 --- a/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.cpp +++ b/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -135,62 +136,61 @@ template template void StreamingDAGResponseWriter::encodeThenWriteBlocks( const std::vector & input_blocks, - tipb::SelectResponse & response) const + TrackedSelectResp & response) const { if (dag_context.encode_type == tipb::EncodeType::TypeCHBlock) { if (dag_context.isMPPTask()) /// broadcast data among TiFlash nodes in MPP { - mpp::MPPDataPacket packet; + TrackedMppDataPacket tracked_packet(current_memory_tracker); if constexpr (send_exec_summary_at_last) { - serializeToPacket(packet, response); + tracked_packet.serializeByResponse(response.getResponse()); } if (input_blocks.empty()) { if constexpr (send_exec_summary_at_last) { - writer->write(packet); + writer->write(tracked_packet.getPacket()); } return; } for (const auto & block : input_blocks) { chunk_codec_stream->encode(block, 0, block.rows()); - packet.add_chunks(chunk_codec_stream->getString()); + tracked_packet.addChunk(chunk_codec_stream->getString()); chunk_codec_stream->clear(); } - writer->write(packet); + writer->write(tracked_packet.getPacket()); } else /// passthrough data to a non-TiFlash node, like sending data to TiSpark { - response.set_encode_type(dag_context.encode_type); + response.setEncodeType(dag_context.encode_type); if (input_blocks.empty()) { if constexpr (send_exec_summary_at_last) { - writer->write(response); + writer->write(response.getResponse()); } return; } for (const auto & block : input_blocks) { chunk_codec_stream->encode(block, 0, block.rows()); - auto * dag_chunk = response.add_chunks(); - dag_chunk->set_rows_data(chunk_codec_stream->getString()); + response.addChunk(chunk_codec_stream->getString()); chunk_codec_stream->clear(); } - writer->write(response); + writer->write(response.getResponse()); } } else /// passthrough data to a TiDB node { - response.set_encode_type(dag_context.encode_type); + response.setEncodeType(dag_context.encode_type); if (input_blocks.empty()) { if constexpr (send_exec_summary_at_last) { - writer->write(response); + writer->write(response.getResponse()); } return; } @@ -203,8 +203,7 @@ void StreamingDAGResponseWriter::e { if (current_records_num >= records_per_chunk) { - auto * dag_chunk = response.add_chunks(); - dag_chunk->set_rows_data(chunk_codec_stream->getString()); + response.addChunk(chunk_codec_stream->getString()); chunk_codec_stream->clear(); current_records_num = 0; } @@ -217,11 +216,10 @@ void StreamingDAGResponseWriter::e if (current_records_num > 0) { - auto * dag_chunk = response.add_chunks(); - dag_chunk->set_rows_data(chunk_codec_stream->getString()); + response.addChunk(chunk_codec_stream->getString()); chunk_codec_stream->clear(); } - writer->write(response); + writer->write(response.getResponse()); } } @@ -230,9 +228,9 @@ template template void StreamingDAGResponseWriter::batchWrite() { - tipb::SelectResponse response; + TrackedSelectResp response; if constexpr (send_exec_summary_at_last) - addExecuteSummaries(response, !dag_context.isMPPTask() || dag_context.isRootMPPTask()); + addExecuteSummaries(response.getResponse(), !dag_context.isMPPTask() || dag_context.isRootMPPTask()); if (exchange_type == tipb::ExchangeType::Hash) { partitionAndEncodeThenWriteBlocks(blocks, response); @@ -249,13 +247,13 @@ template template void StreamingDAGResponseWriter::handleExecSummary( const std::vector & input_blocks, - std::vector & packet, + std::vector & packets, tipb::SelectResponse & response) const { if constexpr (send_exec_summary_at_last) { /// Sending the response to only one node, default the first one. - serializeToPacket(packet[0], response); + packets[0].serializeByResponse(response); // No need to send data when blocks are not empty, // because exec_summary will be sent together with blocks. @@ -263,7 +261,7 @@ void StreamingDAGResponseWriter::h { for (auto part_id = 0; part_id < partition_num; ++part_id) { - writer->write(packet[part_id], part_id); + writer->write(packets[part_id].getPacket(), part_id); } } } @@ -273,18 +271,18 @@ template template void StreamingDAGResponseWriter::writePackets( const std::vector & responses_row_count, - std::vector & packets) const + std::vector & packets) const { for (size_t part_id = 0; part_id < packets.size(); ++part_id) { if constexpr (send_exec_summary_at_last) { - writer->write(packets[part_id], part_id); + writer->write(packets[part_id].getPacket(), part_id); } else { if (responses_row_count[part_id] > 0) - writer->write(packets[part_id], part_id); + writer->write(packets[part_id].getPacket(), part_id); } } } @@ -354,12 +352,12 @@ template template void StreamingDAGResponseWriter::partitionAndEncodeThenWriteBlocks( std::vector & input_blocks, - tipb::SelectResponse & response) const + TrackedSelectResp & response) const { static_assert(!enable_fine_grained_shuffle); - std::vector packet(partition_num); + std::vector tracked_packets(partition_num); std::vector responses_row_count(partition_num); - handleExecSummary(input_blocks, packet, response); + handleExecSummary(input_blocks, tracked_packets, response.getResponse()); if (input_blocks.empty()) return; @@ -378,12 +376,12 @@ void StreamingDAGResponseWriter::p dest_block.setColumns(std::move(dest_tbl_cols[part_id])); responses_row_count[part_id] += dest_block.rows(); chunk_codec_stream->encode(dest_block, 0, dest_block.rows()); - packet[part_id].add_chunks(chunk_codec_stream->getString()); + tracked_packets[part_id].addChunk(chunk_codec_stream->getString()); chunk_codec_stream->clear(); } } - writePackets(responses_row_count, packet); + writePackets(responses_row_count, tracked_packets); } /// Hash exchanging data among only TiFlash nodes. Only be called when enable_fine_grained_shuffle is true. @@ -399,12 +397,12 @@ void StreamingDAGResponseWriter::b if constexpr (send_exec_summary_at_last) addExecuteSummaries(response, !dag_context.isMPPTask() || dag_context.isRootMPPTask()); - std::vector packet(partition_num); + std::vector tracked_packets(partition_num); std::vector responses_row_count(partition_num, 0); // fine_grained_shuffle_stream_count is in [0, 1024], and partition_num is uint16_t, so will not overflow. uint32_t bucket_num = partition_num * fine_grained_shuffle_stream_count; - handleExecSummary(blocks, packet, response); + handleExecSummary(blocks, tracked_packets, response); if (!blocks.empty()) { std::vector final_dest_tbl_columns(bucket_num); @@ -441,15 +439,15 @@ void StreamingDAGResponseWriter::b row_count_per_part += dest_block.rows(); chunk_codec_stream->encode(dest_block, 0, dest_block.rows()); - packet[part_id].add_chunks(chunk_codec_stream->getString()); - packet[part_id].add_stream_ids(stream_idx); + tracked_packets[part_id].addChunk(chunk_codec_stream->getString()); + tracked_packets[part_id].packet.add_stream_ids(stream_idx); chunk_codec_stream->clear(); } responses_row_count[part_id] = row_count_per_part; } } - writePackets(responses_row_count, packet); + writePackets(responses_row_count, tracked_packets); blocks.clear(); rows_in_blocks = 0; diff --git a/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.h b/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.h index cd7559d1e79..1e37090509b 100644 --- a/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.h +++ b/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.h @@ -23,6 +23,7 @@ #include #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wunused-parameter" +#include #include #include @@ -58,16 +59,16 @@ class StreamingDAGResponseWriter : public DAGResponseWriter void batchWriteFineGrainedShuffle(); template - void encodeThenWriteBlocks(const std::vector & input_blocks, tipb::SelectResponse & response) const; + void encodeThenWriteBlocks(const std::vector & input_blocks, TrackedSelectResp & response) const; template - void partitionAndEncodeThenWriteBlocks(std::vector & input_blocks, tipb::SelectResponse & response) const; + void partitionAndEncodeThenWriteBlocks(std::vector & input_blocks, TrackedSelectResp & response) const; template void handleExecSummary(const std::vector & input_blocks, - std::vector & packet, + std::vector & packet, tipb::SelectResponse & response) const; template - void writePackets(const std::vector & responses_row_count, std::vector & packets) const; + void writePackets(const std::vector & responses_row_count, std::vector & packets) const; Int64 batch_send_min_limit; bool should_send_exec_summary_at_last; /// only one stream needs to sending execution summaries at last. diff --git a/dbms/src/Flash/Mpp/ExchangeReceiver.cpp b/dbms/src/Flash/Mpp/ExchangeReceiver.cpp index c896757c84a..3cdb0c2a184 100644 --- a/dbms/src/Flash/Mpp/ExchangeReceiver.cpp +++ b/dbms/src/Flash/Mpp/ExchangeReceiver.cpp @@ -59,26 +59,27 @@ String getReceiverStateStr(const ExchangeReceiverState & s) template bool pushPacket(size_t source_index, const String & req_info, - MPPDataPacketPtr & packet, + const TrackedMppDataPacketPtr & tracked_packet, const std::vector & msg_channels, LoggerPtr & log) { bool push_succeed = true; const mpp::Error * error_ptr = nullptr; - if (packet->has_error()) - error_ptr = &packet->error(); + auto & packet = tracked_packet->packet; + if (packet.has_error()) + error_ptr = &packet.error(); const String * resp_ptr = nullptr; - if (!packet->data().empty()) - resp_ptr = &packet->data(); + if (!packet.data().empty()) + resp_ptr = &packet.data(); if constexpr (enable_fine_grained_shuffle) { std::vector> chunks(msg_channels.size()); - if (!packet->chunks().empty()) + if (!packet.chunks().empty()) { // Packet not empty. - if (unlikely(packet->stream_ids().empty())) + if (unlikely(packet.stream_ids().empty())) { // Fine grained shuffle is enabled in receiver, but sender didn't. We cannot handle this, so return error. // This can happen when there are old version nodes when upgrading. @@ -90,12 +91,12 @@ bool pushPacket(size_t source_index, } // packet.stream_ids[i] is corresponding to packet.chunks[i], // indicating which stream_id this chunk belongs to. - assert(packet->chunks_size() == packet->stream_ids_size()); + assert(packet.chunks_size() == packet.stream_ids_size()); - for (int i = 0; i < packet->stream_ids_size(); ++i) + for (int i = 0; i < packet.stream_ids_size(); ++i) { - UInt64 stream_id = packet->stream_ids(i) % msg_channels.size(); - chunks[stream_id].push_back(&packet->chunks(i)); + UInt64 stream_id = packet.stream_ids(i) % msg_channels.size(); + chunks[stream_id].push_back(&packet.chunks(i)); } } // Still need to send error_ptr or resp_ptr even if packet.chunks_size() is zero. @@ -107,7 +108,7 @@ bool pushPacket(size_t source_index, std::shared_ptr recv_msg = std::make_shared( source_index, req_info, - packet, + tracked_packet, error_ptr, resp_ptr, std::move(chunks[i])); @@ -123,10 +124,10 @@ bool pushPacket(size_t source_index, } else { - std::vector chunks(packet->chunks_size()); - for (int i = 0; i < packet->chunks_size(); ++i) + std::vector chunks(packet.chunks_size()); + for (int i = 0; i < packet.chunks_size(); ++i) { - chunks[i] = &packet->chunks(i); + chunks[i] = &packet.chunks(i); } if (!(resp_ptr == nullptr && error_ptr == nullptr && chunks.empty())) @@ -134,7 +135,7 @@ bool pushPacket(size_t source_index, std::shared_ptr recv_msg = std::make_shared( source_index, req_info, - packet, + tracked_packet, error_ptr, resp_ptr, std::move(chunks)); @@ -190,7 +191,7 @@ class AsyncRequestHandler : public UnaryCallback { packets.resize(batch_packet_count); for (auto & packet : packets) - packet = std::make_shared(); + packet = std::make_shared(); start(); } @@ -212,7 +213,7 @@ class AsyncRequestHandler : public UnaryCallback case AsyncRequestStage::WAIT_BATCH_READ: if (ok) ++read_packet_index; - if (!ok || read_packet_index == batch_packet_count || packets[read_packet_index - 1]->has_error()) + if (!ok || read_packet_index == batch_packet_count || packets[read_packet_index - 1]->hasError()) notifyReactor(); else reader->read(packets[read_packet_index], thisAsUnaryCallback()); @@ -228,6 +229,7 @@ class AsyncRequestHandler : public UnaryCallback // handle will be called by ExchangeReceiver::reactor. void handle() { + std::string err_info; LOG_FMT_TRACE(log, "stage: {}", stage); switch (stage) { @@ -251,8 +253,8 @@ class AsyncRequestHandler : public UnaryCallback if (auto packet = getErrorPacket()) setDone("Exchange receiver meet error : " + packet->error().msg()); - else if (!sendPackets()) - setDone("Exchange receiver meet error : push packets fail"); + else if (!sendPackets(err_info)) + setDone("Exchange receiver meet error : push packets fail, " + err_info); else if (read_packet_index < batch_packet_count) { stage = AsyncRequestStage::WAIT_FINISH; @@ -314,10 +316,10 @@ class AsyncRequestHandler : public UnaryCallback notify_queue->push(this); } - MPPDataPacketPtr getErrorPacket() const + TrackedMppDataPacketPtr getErrorPacket() const { // only the last packet may has error, see execute(). - if (read_packet_index != 0 && packets[read_packet_index - 1]->has_error()) + if (read_packet_index != 0 && packets[read_packet_index - 1]->hasError()) return packets[read_packet_index - 1]; return nullptr; } @@ -357,15 +359,31 @@ class AsyncRequestHandler : public UnaryCallback setDone(done_msg); } - bool sendPackets() + bool sendPackets(std::string & err_info) { + // note: no exception should be thrown rudely, since it's called by a GRPC poller. for (size_t i = 0; i < read_packet_index; ++i) { auto & packet = packets[i]; - if (!pushPacket(request->source_index, req_info, packet, *msg_channels, log)) + // We shouldn't throw error directly, since the caller works in a standalone thread. + try + { + packet->recomputeTrackedMem(); + if (!pushPacket( + request->source_index, + req_info, + packet, + *msg_channels, + log)) + return false; + } + catch (...) + { + err_info = getCurrentExceptionMessage(false); return false; + } // can't reuse packet since it is sent to readers. - packet = std::make_shared(); + packet = std::make_shared(); } return true; } @@ -390,7 +408,7 @@ class AsyncRequestHandler : public UnaryCallback AsyncRequestStage stage = AsyncRequestStage::NEED_INIT; std::shared_ptr reader; - MPPDataPacketPtrs packets; + TrackedMPPDataPacketPtrs packets; size_t read_packet_index = 0; Status finish_status = RPCContext::getStatusOK(); LoggerPtr log; @@ -405,7 +423,8 @@ ExchangeReceiverBase::ExchangeReceiverBase( size_t max_streams_, const String & req_id, const String & executor_id, - uint64_t fine_grained_shuffle_stream_count_) + uint64_t fine_grained_shuffle_stream_count_, + bool setup_conn_manually) : rpc_context(std::move(rpc_context_)) , source_num(source_num_) , max_streams(max_streams_) @@ -431,7 +450,11 @@ ExchangeReceiverBase::ExchangeReceiverBase( msg_channels.push_back(std::make_unique>>(max_buffer_size)); } rpc_context->fillSchema(schema); - setUpConnection(); + if (!setup_conn_manually) + { + // In CH client case, we need setUpConn right now. However, MPPTask will setUpConnection manually after ProcEntry is created. + setUpConnection(); + } } catch (...) { @@ -479,6 +502,8 @@ void ExchangeReceiverBase::close() template void ExchangeReceiverBase::setUpConnection() { + if (thread_count) + return; std::vector async_requests; for (size_t index = 0; index < source_num; ++index) @@ -601,15 +626,20 @@ void ExchangeReceiverBase::readLoop(const Request & req) for (;;) { LOG_FMT_TRACE(log, "begin next "); - MPPDataPacketPtr packet = std::make_shared(); + TrackedMppDataPacketPtr packet = std::make_shared(); bool success = reader->read(packet); if (!success) break; has_data = true; - if (packet->has_error()) + if (packet->hasError()) throw Exception("Exchange receiver meet error : " + packet->error().msg()); - if (!pushPacket(req.source_index, req_info, packet, msg_channels, log)) + if (!pushPacket( + req.source_index, + req_info, + packet, + msg_channels, + log)) { meet_error = true; auto local_state = getState(); @@ -681,9 +711,10 @@ DecodeDetail ExchangeReceiverBase::decodeChunks( if (recv_msg->chunks.empty()) return detail; + auto & packet = recv_msg->packet->packet; // Record total packet size even if fine grained shuffle is enabled. - detail.packet_bytes = recv_msg->packet->ByteSizeLong(); + detail.packet_bytes = packet.ByteSizeLong(); for (const String * chunk : recv_msg->chunks) { diff --git a/dbms/src/Flash/Mpp/ExchangeReceiver.h b/dbms/src/Flash/Mpp/ExchangeReceiver.h index 9213eb76e60..8eed7878545 100644 --- a/dbms/src/Flash/Mpp/ExchangeReceiver.h +++ b/dbms/src/Flash/Mpp/ExchangeReceiver.h @@ -38,7 +38,7 @@ struct ReceivedMessage size_t source_index; String req_info; // shared_ptr is copied to make sure error_ptr, resp_ptr and chunks are valid. - const std::shared_ptr packet; + const std::shared_ptr packet; const mpp::Error * error_ptr; const String * resp_ptr; std::vector chunks; @@ -46,7 +46,7 @@ struct ReceivedMessage // Constructor that move chunks. ReceivedMessage(size_t source_index_, const String & req_info_, - const std::shared_ptr & packet_, + const std::shared_ptr & packet_, const mpp::Error * error_ptr_, const String * resp_ptr_, std::vector && chunks_) @@ -129,10 +129,13 @@ class ExchangeReceiverBase size_t max_streams_, const String & req_id, const String & executor_id, - uint64_t fine_grained_shuffle_stream_count); + uint64_t fine_grained_shuffle_stream_count, + bool setup_conn_manually = false); ~ExchangeReceiverBase(); + void setUpConnection(); + void cancel(); void close(); @@ -166,7 +169,6 @@ class ExchangeReceiverBase private: using Request = typename RPCContext::Request; - void setUpConnection(); // Template argument enable_fine_grained_shuffle will be setup properly in setUpConnection(). template void readLoop(const Request & req); diff --git a/dbms/src/Flash/Mpp/GRPCReceiverContext.cpp b/dbms/src/Flash/Mpp/GRPCReceiverContext.cpp index 310745aa024..236d15f9093 100644 --- a/dbms/src/Flash/Mpp/GRPCReceiverContext.cpp +++ b/dbms/src/Flash/Mpp/GRPCReceiverContext.cpp @@ -63,9 +63,9 @@ struct GrpcExchangePacketReader : public ExchangePacketReader call = std::make_shared>(req.req); } - bool read(MPPDataPacketPtr & packet) override + bool read(TrackedMppDataPacketPtr & packet) override { - return reader->Read(packet.get()); + return packet->read(reader); } ::grpc::Status finish() override @@ -101,9 +101,9 @@ struct AsyncGrpcExchangePacketReader : public AsyncExchangePacketReader callback); } - void read(MPPDataPacketPtr & packet, UnaryCallback * callback) override + void read(TrackedMppDataPacketPtr & packet, UnaryCallback * callback) override { - reader->Read(packet.get(), callback); + packet->read(reader, callback); } void finish(::grpc::Status & status, UnaryCallback * callback) override @@ -131,9 +131,9 @@ struct LocalExchangePacketReader : public ExchangePacketReader } } - bool read(MPPDataPacketPtr & packet) override + bool read(TrackedMppDataPacketPtr & packet) override { - MPPDataPacketPtr tmp_packet = local_tunnel_sender->readForLocal(); + TrackedMppDataPacketPtr tmp_packet = local_tunnel_sender->readForLocal(); bool success = tmp_packet != nullptr; if (success) packet = tmp_packet; diff --git a/dbms/src/Flash/Mpp/GRPCReceiverContext.h b/dbms/src/Flash/Mpp/GRPCReceiverContext.h index 3868271ff8a..a6e9afb059b 100644 --- a/dbms/src/Flash/Mpp/GRPCReceiverContext.h +++ b/dbms/src/Flash/Mpp/GRPCReceiverContext.h @@ -28,14 +28,14 @@ namespace DB { using MPPDataPacket = mpp::MPPDataPacket; -using MPPDataPacketPtr = std::shared_ptr; -using MPPDataPacketPtrs = std::vector; +using TrackedMppDataPacketPtr = std::shared_ptr; +using TrackedMPPDataPacketPtrs = std::vector; class ExchangePacketReader { public: virtual ~ExchangePacketReader() = default; - virtual bool read(MPPDataPacketPtr & packet) = 0; + virtual bool read(TrackedMppDataPacketPtr & packet) = 0; virtual ::grpc::Status finish() = 0; }; using ExchangePacketReaderPtr = std::shared_ptr; @@ -45,7 +45,7 @@ class AsyncExchangePacketReader public: virtual ~AsyncExchangePacketReader() = default; virtual void init(UnaryCallback * callback) = 0; - virtual void read(MPPDataPacketPtr & packet, UnaryCallback * callback) = 0; + virtual void read(TrackedMppDataPacketPtr & packet, UnaryCallback * callback) = 0; virtual void finish(::grpc::Status & status, UnaryCallback * callback) = 0; }; using AsyncExchangePacketReaderPtr = std::shared_ptr; diff --git a/dbms/src/Flash/Mpp/MPPReceiverSet.cpp b/dbms/src/Flash/Mpp/MPPReceiverSet.cpp index fd8da091224..c433d8aa9a9 100644 --- a/dbms/src/Flash/Mpp/MPPReceiverSet.cpp +++ b/dbms/src/Flash/Mpp/MPPReceiverSet.cpp @@ -44,6 +44,15 @@ void MPPReceiverSet::cancel() cop_reader->cancel(); } + +void MPPReceiverSet::setUpConnection() +{ + for (auto & it : exchange_receiver_map) + { + it.second->setUpConnection(); + } +} + void MPPReceiverSet::close() { for (auto & it : exchange_receiver_map) diff --git a/dbms/src/Flash/Mpp/MPPReceiverSet.h b/dbms/src/Flash/Mpp/MPPReceiverSet.h index 367fd6859ce..71ae8606e8f 100644 --- a/dbms/src/Flash/Mpp/MPPReceiverSet.h +++ b/dbms/src/Flash/Mpp/MPPReceiverSet.h @@ -28,6 +28,7 @@ class MPPReceiverSet void addExchangeReceiver(const String & executor_id, const ExchangeReceiverPtr & exchange_receiver); void addCoprocessorReader(const CoprocessorReaderPtr & coprocessor_reader); ExchangeReceiverPtr getExchangeReceiver(const String & executor_id) const; + void setUpConnection(); void cancel(); void close(); diff --git a/dbms/src/Flash/Mpp/MPPTask.cpp b/dbms/src/Flash/Mpp/MPPTask.cpp index 1444d2f1963..121352595aa 100644 --- a/dbms/src/Flash/Mpp/MPPTask.cpp +++ b/dbms/src/Flash/Mpp/MPPTask.cpp @@ -201,7 +201,8 @@ void MPPTask::initExchangeReceivers() context->getMaxStreams(), log->identifier(), executor_id, - executor.fine_grained_shuffle_stream_count()); + executor.fine_grained_shuffle_stream_count(), + true); if (status != RUNNING) throw Exception("exchange receiver map can not be initialized, because the task is not in running state"); diff --git a/dbms/src/Flash/Mpp/MPPTunnel.cpp b/dbms/src/Flash/Mpp/MPPTunnel.cpp index b1a25095ba0..7a9f4cdea76 100644 --- a/dbms/src/Flash/Mpp/MPPTunnel.cpp +++ b/dbms/src/Flash/Mpp/MPPTunnel.cpp @@ -68,8 +68,9 @@ MPPTunnel::MPPTunnel( , status(TunnelStatus::Unconnected) , timeout(timeout_) , tunnel_id(tunnel_id_) - , send_queue(std::make_shared>(std::max(5, input_steams_num_ * 5))) // MPMCQueue can benefit from a slightly larger queue size + , send_queue(std::make_shared>(std::max(5, input_steams_num_ * 5))) // MPMCQueue can benefit from a slightly larger queue size , log(Logger::get("MPPTunnel", req_id, tunnel_id)) + , mem_tracker(current_memory_tracker) { RUNTIME_ASSERT(!(is_local_ && is_async_), log, "is_local: {}, is_async: {}.", is_local_, is_async_); if (is_local_) @@ -109,6 +110,10 @@ void MPPTunnel::finishSendQueue() /// exit abnormally, such as being cancelled. void MPPTunnel::close(const String & reason) { + SCOPE_EXIT({ + // ensure the tracked memory is released and udpated before memotry tracker(in ProcListEntry) is released + send_queue->finishAndDrain(); // drain the send_queue when close + }); { std::unique_lock lk(*mu); switch (status) @@ -124,7 +129,7 @@ void MPPTunnel::close(const String & reason) try { FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::exception_during_mpp_close_tunnel); - send_queue->push(std::make_shared(getPacketWithError(reason))); + send_queue->push(std::make_shared(getPacketWithError(reason), mem_tracker)); if (mode == TunnelSenderMode::ASYNC_GRPC) async_tunnel_sender->tryFlushOne(); } @@ -159,7 +164,7 @@ void MPPTunnel::write(const mpp::MPPDataPacket & data, bool close_after_write) if (status == TunnelStatus::Finished) throw Exception(fmt::format("write to tunnel which is already closed,{}", tunnel_sender ? tunnel_sender->getConsumerFinishMsg() : "")); - if (send_queue->push(std::make_shared(data)) == MPMCQueueResult::OK) + if (send_queue->push(std::make_shared(data, mem_tracker)) == MPMCQueueResult::OK) { connection_profile_info.bytes += data.ByteSizeLong(); connection_profile_info.packets += 1; @@ -301,6 +306,11 @@ StringRef MPPTunnel::statusToString() } } +void MPPTunnel::updateMemTracker() +{ + mem_tracker = current_memory_tracker; +} + void TunnelSender::consumerFinish(const String & msg) { LOG_FMT_TRACE(log, "calling consumer Finish"); @@ -321,10 +331,10 @@ void SyncTunnelSender::sendJob() String err_msg; try { - MPPDataPacketPtr res; + TrackedMppDataPacketPtr res; while (send_queue->pop(res) == MPMCQueueResult::OK) { - if (!writer->write(*res)) + if (!writer->write(res->packet)) { err_msg = "grpc writes failed."; break; @@ -379,11 +389,11 @@ void AsyncTunnelSender::sendOne(bool use_lock) bool queue_empty_flag = false; try { - MPPDataPacketPtr res; + TrackedMppDataPacketPtr res; queue_empty_flag = send_queue->pop(res) != MPMCQueueResult::OK; if (!queue_empty_flag) { - if (!writer->write(*res)) + if (!writer->write(res->packet)) { err_msg = "grpc writes failed."; } @@ -413,11 +423,15 @@ void AsyncTunnelSender::sendOne(bool use_lock) } } -LocalTunnelSender::MPPDataPacketPtr LocalTunnelSender::readForLocal() +std::shared_ptr LocalTunnelSender::readForLocal() { - MPPDataPacketPtr res; + TrackedMppDataPacketPtr res; if (send_queue->pop(res) == MPMCQueueResult::OK) + { + // switch tunnel's memory tracker into receiver's + res->switchMemTracker(current_memory_tracker); return res; + } consumerFinish(""); return nullptr; } diff --git a/dbms/src/Flash/Mpp/MPPTunnel.h b/dbms/src/Flash/Mpp/MPPTunnel.h index 5243c9aaf36..e6e6ad94bbd 100644 --- a/dbms/src/Flash/Mpp/MPPTunnel.h +++ b/dbms/src/Flash/Mpp/MPPTunnel.h @@ -33,6 +33,8 @@ #include #pragma GCC diagnostic pop +#include + #include #include #include @@ -61,8 +63,8 @@ enum class TunnelSenderMode class TunnelSender : private boost::noncopyable { public: - using MPPDataPacketPtr = std::shared_ptr; - using DataPacketMPMCQueuePtr = std::shared_ptr>; + using TrackedMppDataPacketPtr = std::shared_ptr; + using DataPacketMPMCQueuePtr = std::shared_ptr>; virtual ~TunnelSender() = default; TunnelSender(TunnelSenderMode mode_, DataPacketMPMCQueuePtr send_queue_, PacketWriter * writer_, const LoggerPtr log_, const String & tunnel_id_) : mode(mode_) @@ -177,7 +179,7 @@ class LocalTunnelSender : public TunnelSender public: using Base = TunnelSender; using Base::Base; - MPPDataPacketPtr readForLocal(); + TrackedMppDataPacketPtr readForLocal(); }; using TunnelSenderPtr = std::shared_ptr; @@ -202,7 +204,7 @@ using LocalTunnelSenderPtr = std::shared_ptr; * To be short: before connect, only close can finish a MPPTunnel; after connect, only Sender Finish can. * * Each MPPTunnel has a Sender to consume data. There're three kinds of senders: sync_remote, local and async_remote. - * + * * The protocol between MPPTunnel and Sender: * - All data will be pushed into the `send_queue`, including errors. * - MPPTunnel may finish `send_queue` to notify Sender normally finish. @@ -259,6 +261,8 @@ class MPPTunnel : private boost::noncopyable const LoggerPtr & getLogger() const { return log; } + void updateMemTracker(); + TunnelSenderPtr getTunnelSender() { return tunnel_sender; } SyncTunnelSenderPtr getSyncTunnelSender() { return sync_tunnel_sender; } AsyncTunnelSenderPtr getAsyncTunnelSender() { return async_tunnel_sender; } @@ -292,11 +296,12 @@ class MPPTunnel : private boost::noncopyable // tunnel id is in the format like "tunnel[sender]+[receiver]" String tunnel_id; - using MPPDataPacketPtr = std::shared_ptr; - using DataPacketMPMCQueuePtr = std::shared_ptr>; + using TrackedMppDataPacketPtr = std::shared_ptr; + using DataPacketMPMCQueuePtr = std::shared_ptr>; DataPacketMPMCQueuePtr send_queue; ConnectionProfileInfo connection_profile_info; const LoggerPtr log; + MemoryTracker * mem_tracker; TunnelSenderMode mode; // Tunnel transfer data mode TunnelSenderPtr tunnel_sender; // Used to refer to one of sync/async/local_tunnel_sender which is not nullptr, just for coding convenience // According to mode value, among the sync/async/local_tunnel_senders, only the responding sender is not null and do actual work diff --git a/dbms/src/Flash/Mpp/MPPTunnelSet.cpp b/dbms/src/Flash/Mpp/MPPTunnelSet.cpp index 3de5af31091..e6610d4d7b8 100644 --- a/dbms/src/Flash/Mpp/MPPTunnelSet.cpp +++ b/dbms/src/Flash/Mpp/MPPTunnelSet.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include @@ -26,13 +27,6 @@ extern const char exception_during_mpp_write_err_to_tunnel[]; } // namespace FailPoints namespace { -inline mpp::MPPDataPacket serializeToPacket(const tipb::SelectResponse & response) -{ - mpp::MPPDataPacket packet; - if (!response.SerializeToString(packet.mutable_data())) - throw Exception(fmt::format("Fail to serialize response, response size: {}", response.ByteSizeLong())); - return packet; -} void checkPacketSize(size_t size) { @@ -57,11 +51,19 @@ void MPPTunnelSetBase::clearExecutionSummaries(tipb::SelectResponse & re } } +template +void MPPTunnelSetBase::updateMemTracker() +{ + for (size_t i = 0; i < tunnels.size(); ++i) + tunnels[i]->updateMemTracker(); +} + template void MPPTunnelSetBase::write(tipb::SelectResponse & response) { - auto packet = serializeToPacket(response); - tunnels[0]->write(packet); + TrackedMppDataPacket tracked_packet; + tracked_packet.serializeByResponse(response); + tunnels[0]->write(tracked_packet.getPacket()); if (tunnels.size() > 1) { @@ -69,10 +71,11 @@ void MPPTunnelSetBase::write(tipb::SelectResponse & response) if (response.execution_summaries_size() > 0) { clearExecutionSummaries(response); - packet = serializeToPacket(response); + tracked_packet = TrackedMppDataPacket(); + tracked_packet.serializeByResponse(response); } for (size_t i = 1; i < tunnels.size(); ++i) - tunnels[i]->write(packet); + tunnels[i]->write(tracked_packet.getPacket()); } } @@ -98,10 +101,11 @@ void MPPTunnelSetBase::write(mpp::MPPDataPacket & packet) template void MPPTunnelSetBase::write(tipb::SelectResponse & response, int16_t partition_id) { + TrackedMppDataPacket tracked_packet; if (partition_id != 0 && response.execution_summaries_size() > 0) clearExecutionSummaries(response); - - tunnels[partition_id]->write(serializeToPacket(response)); + tracked_packet.serializeByResponse(response); + tunnels[partition_id]->write(tracked_packet.getPacket()); } template diff --git a/dbms/src/Flash/Mpp/MPPTunnelSet.h b/dbms/src/Flash/Mpp/MPPTunnelSet.h index e4123db1be5..da37423876e 100644 --- a/dbms/src/Flash/Mpp/MPPTunnelSet.h +++ b/dbms/src/Flash/Mpp/MPPTunnelSet.h @@ -58,6 +58,7 @@ class MPPTunnelSetBase : private boost::noncopyable void close(const String & reason); void finishWrite(); void registerTunnel(const MPPTaskId & id, const TunnelPtr & tunnel); + void updateMemTracker(); TunnelPtr getTunnelByReceiverTaskId(const MPPTaskId & id); diff --git a/dbms/src/Flash/Mpp/TrackedMppDataPacket.h b/dbms/src/Flash/Mpp/TrackedMppDataPacket.h new file mode 100644 index 00000000000..7cb2103d4f3 --- /dev/null +++ b/dbms/src/Flash/Mpp/TrackedMppDataPacket.h @@ -0,0 +1,222 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wnon-virtual-dtor" +#ifdef __clang__ +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" +#endif +#include +#include +#include +#include +#include +#pragma GCC diagnostic pop +#include + +#include + +namespace DB +{ +inline size_t estimateAllocatedSize(const mpp::MPPDataPacket & data) +{ + size_t ret = data.data().size(); + for (int i = 0; i < data.chunks_size(); i++) + { + ret += data.chunks(i).size(); + } + return ret; +} + +struct MemTrackerWrapper +{ + MemTrackerWrapper(size_t _size, MemoryTracker * memory_tracker) + : memory_tracker(memory_tracker) + , size(0) + { + alloc(_size); + } + + explicit MemTrackerWrapper(MemoryTracker * memory_tracker) + : memory_tracker(memory_tracker) + , size(0) + {} + + void alloc(size_t delta) + { + if (delta) + { + if (memory_tracker) + { + memory_tracker->alloc(delta); + size += delta; + } + } + } + + void free(size_t delta) + { + if (delta) + { + if (memory_tracker) + { + memory_tracker->free(delta); + size -= delta; + } + } + } + + void switchMemTracker(MemoryTracker * new_memory_tracker) + { + int bak_size = size; + freeAll(); + memory_tracker = new_memory_tracker; + alloc(bak_size); + } + ~MemTrackerWrapper() + { + freeAll(); + } + + void freeAll() + { + free(size); + } + MemoryTracker * memory_tracker; + size_t size = 0; +}; + +struct TrackedMppDataPacket +{ + explicit TrackedMppDataPacket(const mpp::MPPDataPacket & data, MemoryTracker * memory_tracker) + : mem_tracker_wrapper(estimateAllocatedSize(data), memory_tracker) + { + packet = data; + } + + explicit TrackedMppDataPacket() + : mem_tracker_wrapper(current_memory_tracker) + {} + + explicit TrackedMppDataPacket(MemoryTracker * memory_tracker) + : mem_tracker_wrapper(memory_tracker) + {} + + void addChunk(std::string && value) + { + mem_tracker_wrapper.alloc(value.size()); + packet.add_chunks(std::move(value)); + } + + void serializeByResponse(const tipb::SelectResponse & response) + { + mem_tracker_wrapper.alloc(response.ByteSizeLong()); + if (!response.SerializeToString(packet.mutable_data())) + { + mem_tracker_wrapper.free(response.ByteSizeLong()); + throw Exception(fmt::format("Fail to serialize response, response size: {}", response.ByteSizeLong())); + } + } + + void read(const std::unique_ptr<::grpc::ClientAsyncReader<::mpp::MPPDataPacket>> & reader, void * callback) + { + reader->Read(&packet, callback); + need_recompute = true; + //we shouldn't update tracker now, since it's an async reader!! + } + + // we need recompute in some cases we can't update memory counter timely, such as async read + void recomputeTrackedMem() + { + if (need_recompute) + { + mem_tracker_wrapper.freeAll(); + mem_tracker_wrapper.alloc(estimateAllocatedSize(packet)); + need_recompute = false; + } + } + + bool read(const std::unique_ptr<::grpc::ClientReader<::mpp::MPPDataPacket>> & reader) + { + bool ret = reader->Read(&packet); + mem_tracker_wrapper.freeAll(); + mem_tracker_wrapper.alloc(estimateAllocatedSize(packet)); + return ret; + } + + void switchMemTracker(MemoryTracker * new_memory_tracker) + { + mem_tracker_wrapper.switchMemTracker(new_memory_tracker); + } + + bool hasError() const + { + return packet.has_error(); + } + + const ::mpp::Error & error() const + { + return packet.error(); + } + + mpp::MPPDataPacket & getPacket() + { + return packet; + } + + MemTrackerWrapper mem_tracker_wrapper; + mpp::MPPDataPacket packet; + bool need_recompute = false; +}; + +struct TrackedSelectResp +{ + explicit TrackedSelectResp() + : memory_tracker(current_memory_tracker) + {} + + void addChunk(std::string && value) + { + memory_tracker.alloc(value.size()); + auto * dag_chunk = response.add_chunks(); + dag_chunk->set_rows_data(std::move(value)); + } + + tipb::SelectResponse & getResponse() + { + return response; + } + + void setEncodeType(::tipb::EncodeType value) + { + response.set_encode_type(value); + } + + tipb::ExecutorExecutionSummary * addExecutionSummary() + { + return response.add_execution_summaries(); + } + + MemTrackerWrapper memory_tracker; + tipb::SelectResponse response; +}; + +} // namespace DB diff --git a/dbms/src/Flash/Mpp/tests/gtest_mpptunnel.cpp b/dbms/src/Flash/Mpp/tests/gtest_mpptunnel.cpp index ae720badb68..59409e5dd0b 100644 --- a/dbms/src/Flash/Mpp/tests/gtest_mpptunnel.cpp +++ b/dbms/src/Flash/Mpp/tests/gtest_mpptunnel.cpp @@ -77,11 +77,11 @@ struct MockLocalReader { while (true) { - MPPDataPacketPtr tmp_packet = local_sender->readForLocal(); + TrackedMppDataPacketPtr tmp_packet = local_sender->readForLocal(); bool success = tmp_packet != nullptr; if (success) { - write_packet_vec.push_back(tmp_packet->data()); + write_packet_vec.push_back(tmp_packet->packet.data()); } else { @@ -118,7 +118,7 @@ struct MockTerminateLocalReader void read() const { - MPPDataPacketPtr tmp_packet = local_sender->readForLocal(); + TrackedMppDataPacketPtr tmp_packet = local_sender->readForLocal(); local_sender->consumerFinish("Receiver closed"); } }; diff --git a/dbms/src/Flash/tests/bench_exchange.h b/dbms/src/Flash/tests/bench_exchange.h index d8300d45740..cb4d62aeebd 100644 --- a/dbms/src/Flash/tests/bench_exchange.h +++ b/dbms/src/Flash/tests/bench_exchange.h @@ -96,7 +96,7 @@ struct MockReceiverContext { // Not implement benchmark for Async GRPC for now. void init(UnaryCallback *) { assert(0); } - void read(MPPDataPacketPtr &, UnaryCallback *) { assert(0); } + void read(TrackedMppDataPacketPtr &, UnaryCallback *) { assert(0); } void finish(::grpc::Status &, UnaryCallback *) { assert(0); } }; diff --git a/dbms/src/Interpreters/ProcessList.cpp b/dbms/src/Interpreters/ProcessList.cpp index 0f667cfd396..5e50d560bb0 100644 --- a/dbms/src/Interpreters/ProcessList.cpp +++ b/dbms/src/Interpreters/ProcessList.cpp @@ -162,6 +162,7 @@ ProcessList::EntryPtr ProcessList::insert( /// not for specific users, sessions or queries, /// because this setting is effectively global. total_memory_tracker.setOrRaiseLimit(settings.max_memory_usage_for_all_queries); + total_memory_tracker.setBytesThatRssLargerThanLimit(settings.bytes_that_rss_larger_than_limit); total_memory_tracker.setDescription("(total)"); user_process_list.user_memory_tracker.setNext(&total_memory_tracker); } diff --git a/dbms/src/Interpreters/Settings.h b/dbms/src/Interpreters/Settings.h index 6442f9c8dd6..52fec1c4c1a 100644 --- a/dbms/src/Interpreters/Settings.h +++ b/dbms/src/Interpreters/Settings.h @@ -341,6 +341,7 @@ struct Settings M(SettingUInt64, max_memory_usage, 0, "Maximum memory usage for processing of single query. Zero means unlimited.") \ M(SettingUInt64, max_memory_usage_for_user, 0, "Maximum memory usage for processing all concurrently running queries for the user. Zero means unlimited.") \ M(SettingUInt64, max_memory_usage_for_all_queries, 0, "Maximum memory usage for processing all concurrently running queries on the server. Zero means unlimited.") \ + M(SettingUInt64, bytes_that_rss_larger_than_limit, 5368709120, "How many bytes RSS(Resident Set Size) can be larger than limit(max_memory_usage_for_all_queries). Default: 5GB ") \ \ M(SettingUInt64, max_network_bandwidth, 0, "The maximum speed of data exchange over the network in bytes per second for a query. Zero means unlimited.") \ M(SettingUInt64, max_network_bytes, 0, "The maximum number of bytes (compressed) to receive or transmit over the network for execution of the query.") \ diff --git a/dbms/src/Interpreters/executeQuery.cpp b/dbms/src/Interpreters/executeQuery.cpp index 77a3a76d842..fc2103d7838 100644 --- a/dbms/src/Interpreters/executeQuery.cpp +++ b/dbms/src/Interpreters/executeQuery.cpp @@ -22,6 +22,8 @@ #include #include #include +#include +#include #include #include #include @@ -231,6 +233,16 @@ std::tuple executeQueryImpl( context.setProcessListElement(&process_list_entry->get()); } + // Do set-up work for tunnels and receivers after ProcessListEntry is constructed, + // so that we can propagate current_memory_tracker into them. + if (context.getDAGContext()) // When using TiFlash client, dag context will be nullptr in this case. + { + if (context.getDAGContext()->tunnel_set) + context.getDAGContext()->tunnel_set->updateMemTracker(); + if (context.getDAGContext()->getMppReceiverSet()) + context.getDAGContext()->getMppReceiverSet()->setUpConnection(); + } + FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::random_interpreter_failpoint); auto interpreter = query_src.interpreter(context, stage); res = interpreter->execute(); diff --git a/dbms/src/Server/FlashGrpcServerHolder.cpp b/dbms/src/Server/FlashGrpcServerHolder.cpp index 1190985004d..c359db24298 100644 --- a/dbms/src/Server/FlashGrpcServerHolder.cpp +++ b/dbms/src/Server/FlashGrpcServerHolder.cpp @@ -84,6 +84,7 @@ FlashGrpcServerHolder::FlashGrpcServerHolder(Context & context, Poco::Util::Laye : log(log_) , is_shutdown(std::make_shared>(false)) { + background_task.begin(); grpc::ServerBuilder builder; if (security_config.has_tls_config) { @@ -190,6 +191,7 @@ FlashGrpcServerHolder::~FlashGrpcServerHolder() LOG_FMT_INFO(log, "Begin to shut down flash service"); flash_service.reset(); LOG_FMT_INFO(log, "Shut down flash service"); + background_task.end(); } catch (...) { diff --git a/dbms/src/Server/FlashGrpcServerHolder.h b/dbms/src/Server/FlashGrpcServerHolder.h index 57146f40aae..054138de246 100644 --- a/dbms/src/Server/FlashGrpcServerHolder.h +++ b/dbms/src/Server/FlashGrpcServerHolder.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once +#include #include #include #include @@ -49,6 +50,7 @@ class FlashGrpcServerHolder std::vector> cqs; std::vector> notify_cqs; std::shared_ptr thread_manager; + CollectProcInfoBackgroundTask background_task; }; } // namespace DB \ No newline at end of file