diff --git a/dbms/src/DataStreams/MultiplexInputStream.h b/dbms/src/DataStreams/MultiplexInputStream.h new file mode 100644 index 00000000000..4fa33262e66 --- /dev/null +++ b/dbms/src/DataStreams/MultiplexInputStream.h @@ -0,0 +1,246 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ +extern const int LOGICAL_ERROR; +} // namespace ErrorCodes + +class MultiPartitionStreamPool +{ +public: + MultiPartitionStreamPool() = default; + + void addPartitionStreams(const BlockInputStreams & cur_streams) + { + if (cur_streams.empty()) + return; + std::unique_lock lk(mu); + streams_queue_by_partition.push_back( + std::make_shared>>()); + for (const auto & stream : cur_streams) + streams_queue_by_partition.back()->push(stream); + added_streams.insert(added_streams.end(), cur_streams.begin(), cur_streams.end()); + } + + std::shared_ptr pickOne() + { + std::unique_lock lk(mu); + if (streams_queue_by_partition.empty()) + return nullptr; + if (streams_queue_id >= static_cast(streams_queue_by_partition.size())) + streams_queue_id = 0; + + auto & q = *streams_queue_by_partition[streams_queue_id]; + std::shared_ptr ret = nullptr; + assert(!q.empty()); + ret = q.front(); + q.pop(); + if (q.empty()) + streams_queue_id = removeQueue(streams_queue_id); + else + streams_queue_id = nextQueueId(streams_queue_id); + return ret; + } + + int exportAddedStreams(BlockInputStreams & ret_streams) + { + std::unique_lock lk(mu); + for (auto & stream : added_streams) + ret_streams.push_back(stream); + return added_streams.size(); + } + + int addedStreamsCnt() + { + std::unique_lock lk(mu); + return added_streams.size(); + } + +private: + int removeQueue(int queue_id) + { + streams_queue_by_partition[queue_id] = nullptr; + if (queue_id != static_cast(streams_queue_by_partition.size()) - 1) + { + swap(streams_queue_by_partition[queue_id], streams_queue_by_partition.back()); + streams_queue_by_partition.pop_back(); + return queue_id; + } + else + { + streams_queue_by_partition.pop_back(); + return 0; + } + } + + int nextQueueId(int queue_id) const + { + if (queue_id + 1 < static_cast(streams_queue_by_partition.size())) + return queue_id + 1; + else + return 0; + } + + static void swap(std::shared_ptr>> & a, + std::shared_ptr>> & b) + { + a.swap(b); + } + + std::vector< + std::shared_ptr>>> + streams_queue_by_partition; + std::vector> added_streams; + int streams_queue_id = 0; + std::mutex mu; +}; + +class MultiplexInputStream final : public IProfilingBlockInputStream +{ +private: + static constexpr auto NAME = "Multiplex"; + +public: + MultiplexInputStream( + std::shared_ptr & shared_pool, + const String & req_id) + : log(Logger::get(NAME, req_id)) + , shared_pool(shared_pool) + { + shared_pool->exportAddedStreams(children); + size_t num_children = children.size(); + if (num_children > 1) + { + Block header = children.at(0)->getHeader(); + for (size_t i = 1; i < num_children; ++i) + assertBlocksHaveEqualStructure( + children[i]->getHeader(), + header, + "MULTIPLEX"); + } + } + + String getName() const override { return NAME; } + + ~MultiplexInputStream() override + { + try + { + if (!all_read) + cancel(false); + } + catch (...) + { + tryLogCurrentException(log, __PRETTY_FUNCTION__); + } + } + + /** Different from the default implementation by trying to stop all sources, + * skipping failed by execution. + */ + void cancel(bool kill) override + { + if (kill) + is_killed = true; + + bool old_val = false; + if (!is_cancelled.compare_exchange_strong( + old_val, + true, + std::memory_order_seq_cst, + std::memory_order_relaxed)) + return; + + if (cur_stream) + { + if (IProfilingBlockInputStream * child = dynamic_cast(&*cur_stream)) + { + child->cancel(kill); + } + } + } + + Block getHeader() const override { return children.at(0)->getHeader(); } + +protected: + /// Do nothing, to make the preparation when underlying InputStream is picked from the pool + void readPrefix() override + { + } + + /** The following options are possible: + * 1. `readImpl` function is called until it returns an empty block. + * Then `readSuffix` function is called and then destructor. + * 2. `readImpl` function is called. At some point, `cancel` function is called perhaps from another thread. + * Then `readSuffix` function is called and then destructor. + * 3. At any time, the object can be destroyed (destructor called). + */ + + Block readImpl() override + { + if (all_read) + return {}; + + Block ret; + while (!cur_stream || !(ret = cur_stream->read())) + { + if (cur_stream) + cur_stream->readSuffix(); // release old inputstream + cur_stream = shared_pool->pickOne(); + if (!cur_stream) + { // shared_pool is empty + all_read = true; + return {}; + } + cur_stream->readPrefix(); + } + return ret; + } + + /// Called either after everything is read, or after cancel. + void readSuffix() override + { + if (!all_read && !is_cancelled) + throw Exception("readSuffix called before all data is read", ErrorCodes::LOGICAL_ERROR); + + if (cur_stream) + { + cur_stream->readSuffix(); + cur_stream = nullptr; + } + } + +private: + LoggerPtr log; + + std::shared_ptr shared_pool; + std::shared_ptr cur_stream; + + bool all_read = false; +}; + +} // namespace DB diff --git a/dbms/src/Flash/Coprocessor/DAGStorageInterpreter.cpp b/dbms/src/Flash/Coprocessor/DAGStorageInterpreter.cpp index df7e504d2c4..14cddd94730 100644 --- a/dbms/src/Flash/Coprocessor/DAGStorageInterpreter.cpp +++ b/dbms/src/Flash/Coprocessor/DAGStorageInterpreter.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -634,6 +635,9 @@ void DAGStorageInterpreter::buildLocalStreams(DAGPipeline & pipeline, size_t max if (total_local_region_num == 0) return; const auto table_query_infos = generateSelectQueryInfos(); + bool has_multiple_partitions = table_query_infos.size() > 1; + // MultiPartitionStreamPool will be disabled in no partition mode or single-partition case + std::shared_ptr stream_pool = has_multiple_partitions ? std::make_shared() : nullptr; for (const auto & table_query_info : table_query_infos) { DAGPipeline current_pipeline; @@ -642,9 +646,6 @@ void DAGStorageInterpreter::buildLocalStreams(DAGPipeline & pipeline, size_t max size_t region_num = query_info.mvcc_query_info->regions_query_info.size(); if (region_num == 0) continue; - /// calculate weighted max_streams for each partition, note at least 1 stream is needed for each partition - size_t current_max_streams = table_query_infos.size() == 1 ? max_streams : (max_streams * region_num + total_local_region_num - 1) / total_local_region_num; - QueryProcessingStage::Enum from_stage = QueryProcessingStage::FetchColumns; assert(storages_with_structure_lock.find(table_id) != storages_with_structure_lock.end()); auto & storage = storages_with_structure_lock[table_id].storage; @@ -654,7 +655,7 @@ void DAGStorageInterpreter::buildLocalStreams(DAGPipeline & pipeline, size_t max { try { - current_pipeline.streams = storage->read(required_columns, query_info, context, from_stage, max_block_size, current_max_streams); + current_pipeline.streams = storage->read(required_columns, query_info, context, from_stage, max_block_size, max_streams); // After getting streams from storage, we need to validate whether Regions have changed or not after learner read. // (by calling `validateQueryInfo`). In case the key ranges of Regions have changed (Region merge/split), those `streams` @@ -778,7 +779,19 @@ void DAGStorageInterpreter::buildLocalStreams(DAGPipeline & pipeline, size_t max throw; } } - pipeline.streams.insert(pipeline.streams.end(), current_pipeline.streams.begin(), current_pipeline.streams.end()); + if (has_multiple_partitions) + stream_pool->addPartitionStreams(current_pipeline.streams); + else + pipeline.streams.insert(pipeline.streams.end(), current_pipeline.streams.begin(), current_pipeline.streams.end()); + } + if (has_multiple_partitions) + { + String req_info = dag_context.isMPPTask() ? dag_context.getMPPTaskId().toString() : ""; + int exposed_streams_cnt = std::min(static_cast(max_streams), stream_pool->addedStreamsCnt()); + for (int i = 0; i < exposed_streams_cnt; ++i) + { + pipeline.streams.push_back(std::make_shared(stream_pool, req_info)); + } } }