Skip to content

Commit

Permalink
Agg: replace the usage of ThreadPool by ThreadPoolManager. (#4582)
Browse files Browse the repository at this point in the history
close #4581
  • Loading branch information
fuzhe1989 authored Apr 4, 2022
1 parent 661d205 commit e235089
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 14 deletions.
26 changes: 15 additions & 11 deletions dbms/src/Interpreters/Aggregator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1122,10 +1122,14 @@ Block Aggregator::prepareBlockAndFillSingleLevel(AggregatedDataVariants & data_v
}


BlocksList Aggregator::prepareBlocksAndFillTwoLevel(AggregatedDataVariants & data_variants, bool final, ThreadPool * thread_pool) const
BlocksList Aggregator::prepareBlocksAndFillTwoLevel(
AggregatedDataVariants & data_variants,
bool final,
ThreadPoolManager * thread_pool,
size_t max_threads) const
{
#define M(NAME) \
else if (data_variants.type == AggregatedDataVariants::Type::NAME) return prepareBlocksAndFillTwoLevelImpl(data_variants, *data_variants.NAME, final, thread_pool);
else if (data_variants.type == AggregatedDataVariants::Type::NAME) return prepareBlocksAndFillTwoLevelImpl(data_variants, *data_variants.NAME, final, thread_pool, max_threads);

if (false) // NOLINT
{
Expand All @@ -1141,9 +1145,9 @@ BlocksList Aggregator::prepareBlocksAndFillTwoLevelImpl(
AggregatedDataVariants & data_variants,
Method & method,
bool final,
ThreadPool * thread_pool) const
ThreadPoolManager * thread_pool,
size_t max_threads) const
{
size_t max_threads = thread_pool ? thread_pool->size() : 1;
if (max_threads > data_variants.aggregates_pools.size())
for (size_t i = data_variants.aggregates_pools.size(); i < max_threads; ++i)
data_variants.aggregates_pools.push_back(std::make_shared<Arena>());
Expand Down Expand Up @@ -1181,7 +1185,7 @@ BlocksList Aggregator::prepareBlocksAndFillTwoLevelImpl(
[thread_id, &converter] { return converter(thread_id); });

if (thread_pool)
thread_pool->schedule(wrapInvocable(true, [thread_id, &tasks] { tasks[thread_id](); }));
thread_pool->schedule(true, [thread_id, &tasks] { tasks[thread_id](); });
else
tasks[thread_id]();
}
Expand Down Expand Up @@ -1227,10 +1231,10 @@ BlocksList Aggregator::convertToBlocks(AggregatedDataVariants & data_variants, b
if (data_variants.empty())
return blocks;

std::unique_ptr<ThreadPool> thread_pool;
std::shared_ptr<ThreadPoolManager> thread_pool;
if (max_threads > 1 && data_variants.sizeWithoutOverflowRow() > 100000 /// TODO Make a custom threshold.
&& data_variants.isTwoLevel()) /// TODO Use the shared thread pool with the `merge` function.
thread_pool = std::make_unique<ThreadPool>(max_threads);
thread_pool = newThreadPoolManager(max_threads);

if (isCancelled())
return BlocksList();
Expand All @@ -1249,7 +1253,7 @@ BlocksList Aggregator::convertToBlocks(AggregatedDataVariants & data_variants, b
if (!data_variants.isTwoLevel())
blocks.emplace_back(prepareBlockAndFillSingleLevel(data_variants, final));
else
blocks.splice(blocks.end(), prepareBlocksAndFillTwoLevel(data_variants, final, thread_pool.get()));
blocks.splice(blocks.end(), prepareBlocksAndFillTwoLevel(data_variants, final, thread_pool.get(), max_threads));
}

if (!final)
Expand Down Expand Up @@ -2002,10 +2006,10 @@ void Aggregator::mergeStream(const BlockInputStreamPtr & stream, AggregatedDataV
}
};

std::unique_ptr<ThreadPool> thread_pool;
std::shared_ptr<ThreadPoolManager> thread_pool;
if (max_threads > 1 && total_input_rows > 100000 /// TODO Make a custom threshold.
&& has_two_level)
thread_pool = std::make_unique<ThreadPool>(max_threads);
thread_pool = newThreadPoolManager(max_threads);

for (const auto & bucket_blocks : bucket_to_blocks)
{
Expand All @@ -2022,7 +2026,7 @@ void Aggregator::mergeStream(const BlockInputStreamPtr & stream, AggregatedDataV
};

if (thread_pool)
thread_pool->schedule(wrapInvocable(true, task));
thread_pool->schedule(true, task);
else
task();
}
Expand Down
11 changes: 8 additions & 3 deletions dbms/src/Interpreters/Aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <Common/HashTable/TwoLevelHashMap.h>
#include <Common/HashTable/TwoLevelStringHashMap.h>
#include <Common/Logger.h>
#include <Common/ThreadManager.h>
#include <DataStreams/IBlockInputStream.h>
#include <DataStreams/SizeLimits.h>
#include <Encryption/FileProvider.h>
Expand All @@ -36,7 +37,6 @@
#include <Poco/TemporaryFile.h>
#include <Storages/Transaction/Collator.h>
#include <common/StringRef.h>
#include <common/ThreadPool.h>
#include <common/logger_useful.h>

#include <functional>
Expand Down Expand Up @@ -1048,14 +1048,19 @@ class Aggregator

Block prepareBlockAndFillWithoutKey(AggregatedDataVariants & data_variants, bool final, bool is_overflows) const;
Block prepareBlockAndFillSingleLevel(AggregatedDataVariants & data_variants, bool final) const;
BlocksList prepareBlocksAndFillTwoLevel(AggregatedDataVariants & data_variants, bool final, ThreadPool * thread_pool) const;
BlocksList prepareBlocksAndFillTwoLevel(
AggregatedDataVariants & data_variants,
bool final,
ThreadPoolManager * thread_pool,
size_t max_threads) const;

template <typename Method>
BlocksList prepareBlocksAndFillTwoLevelImpl(
AggregatedDataVariants & data_variants,
Method & method,
bool final,
ThreadPool * thread_pool) const;
ThreadPoolManager * thread_pool,
size_t max_threads) const;

template <bool no_more_keys, typename Method, typename Table>
void mergeStreamsImplCase(
Expand Down

0 comments on commit e235089

Please sign in to comment.