Skip to content

Commit

Permalink
move memory tracker to mpp task
Browse files Browse the repository at this point in the history
Signed-off-by: xufei <xufeixw@mail.ustc.edu.cn>
  • Loading branch information
windtalker committed Sep 26, 2022
1 parent 6b63a83 commit e403b08
Show file tree
Hide file tree
Showing 20 changed files with 107 additions and 99 deletions.
4 changes: 3 additions & 1 deletion dbms/src/Common/MemoryTracker.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ class MemoryTracker : public std::enable_shared_from_this<MemoryTracker>
/// next should be changed only once: from nullptr to some value.
void setNext(MemoryTracker * elem)
{
next.store(elem, std::memory_order_relaxed);
MemoryTracker * old_val = nullptr;
if (!next.compare_exchange_strong(old_val, elem, std::memory_order_seq_cst, std::memory_order_relaxed))
return;
next_holder = elem ? elem->shared_from_this() : nullptr;
}

Expand Down
4 changes: 0 additions & 4 deletions dbms/src/Flash/Coprocessor/DAGContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,4 @@ const SingleTableRegions & DAGContext::getTableRegionsInfoByTableID(Int64 table_
{
return tables_regions_info.getTableRegionInfoByTableID(table_id);
}
const MPPReceiverSetPtr & DAGContext::getMppReceiverSet() const
{
return mpp_receiver_set;
}
} // namespace DB
4 changes: 3 additions & 1 deletion dbms/src/Flash/Coprocessor/DAGContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -305,13 +305,14 @@ class DAGContext
{
mpp_receiver_set = receiver_set;
}
const MPPReceiverSetPtr & getMppReceiverSet() const;
void addCoprocessorReader(const CoprocessorReaderPtr & coprocessor_reader);
std::vector<CoprocessorReaderPtr> & getCoprocessorReaders();

void addSubquery(const String & subquery_id, SubqueryForSet && subquery);
bool hasSubquery() const { return !subqueries.empty(); }
std::vector<SubqueriesForSets> && moveSubqueries() { return std::move(subqueries); }
void setProcessListEntry(std::shared_ptr<ProcessListEntry> entry) { process_list_entry = entry; }
std::shared_ptr<ProcessListEntry> getProcessListEntry() const { return process_list_entry; }

const tipb::DAGRequest * dag_request;
Int64 compile_time_ns = 0;
Expand Down Expand Up @@ -352,6 +353,7 @@ class DAGContext
void initOutputInfo();

private:
std::shared_ptr<ProcessListEntry> process_list_entry;
/// Hold io for correcting the destruction order.
BlockIO io;
/// profile_streams_map is a map that maps from executor_id to profile BlockInputStreams.
Expand Down
5 changes: 5 additions & 0 deletions dbms/src/Flash/EstablishCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,11 @@ void EstablishCallData::trySendOneMsg()
switch (async_tunnel_sender->pop(res, this))
{
case GRPCSendQueueRes::OK:
/// Note: the memory tracker has to be switched before `write`,
/// because after `write`, `async_tunnel_sender` can be destroyed at any time.
/// Otherwise there is a risk that `res` is destructed after `async_tunnel_sender`
/// is destructed, which may cause the memory tracker in `res` to become invalid.
res->switchMemTracker(nullptr);
write(res->packet);
return;
case GRPCSendQueueRes::FINISHED:
Expand Down
19 changes: 7 additions & 12 deletions dbms/src/Flash/Mpp/ExchangeReceiver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,10 @@ class AsyncRequestHandler : public UnaryCallback<bool>
}
case AsyncRequestStage::WAIT_BATCH_READ:
if (ok)
{
++read_packet_index;
packets[read_packet_index - 1]->recomputeTrackedMem();
}
if (!ok || read_packet_index == batch_packet_count || packets[read_packet_index - 1]->hasError())
notifyReactor();
else
Expand Down Expand Up @@ -244,7 +247,7 @@ class AsyncRequestHandler : public UnaryCallback<bool>
has_data = true;

if (auto packet = getErrorPacket())
setDone("Exchange receiver meet error : " + packet->error().msg());
setDone("Exchange receiver meet error : " + packet->error());
else if (!sendPackets(err_info))
setDone("Exchange receiver meet error : push packets fail, " + err_info);
else if (read_packet_index < batch_packet_count)
Expand Down Expand Up @@ -360,7 +363,6 @@ class AsyncRequestHandler : public UnaryCallback<bool>
// We shouldn't throw error directly, since the caller works in a standalone thread.
try
{
packet->recomputeTrackedMem();
if (!pushPacket<enable_fine_grained_shuffle, false>(
request->source_index,
req_info,
Expand Down Expand Up @@ -415,8 +417,7 @@ ExchangeReceiverBase<RPCContext>::ExchangeReceiverBase(
size_t max_streams_,
const String & req_id,
const String & executor_id,
uint64_t fine_grained_shuffle_stream_count_,
bool setup_conn_manually)
uint64_t fine_grained_shuffle_stream_count_)
: rpc_context(std::move(rpc_context_))
, source_num(source_num_)
, max_streams(max_streams_)
Expand All @@ -442,11 +443,7 @@ ExchangeReceiverBase<RPCContext>::ExchangeReceiverBase(
msg_channels.push_back(std::make_unique<MPMCQueue<std::shared_ptr<ReceivedMessage>>>(max_buffer_size));
}
rpc_context->fillSchema(schema);
if (!setup_conn_manually)
{
// In the CH client case, we need to call setUpConnection right now. However, MPPTask will call setUpConnection manually after the ProcEntry is created.
setUpConnection();
}
setUpConnection();
}
catch (...)
{
Expand Down Expand Up @@ -494,8 +491,6 @@ void ExchangeReceiverBase<RPCContext>::close()
template <typename RPCContext>
void ExchangeReceiverBase<RPCContext>::setUpConnection()
{
if (thread_count)
return;
mem_tracker = current_memory_tracker ? current_memory_tracker->shared_from_this() : nullptr;
std::vector<Request> async_requests;

Expand Down Expand Up @@ -627,7 +622,7 @@ void ExchangeReceiverBase<RPCContext>::readLoop(const Request & req)
if (packet->hasError())
{
meet_error = true;
local_err_msg = fmt::format("Read error message from mpp packet: {}", packet->getPacket().error().msg());
local_err_msg = fmt::format("Read error message from mpp packet: {}", packet->error());
break;
}

Expand Down
6 changes: 2 additions & 4 deletions dbms/src/Flash/Mpp/ExchangeReceiver.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,10 @@ class ExchangeReceiverBase
size_t max_streams_,
const String & req_id,
const String & executor_id,
uint64_t fine_grained_shuffle_stream_count,
bool setup_conn_manually = false);
uint64_t fine_grained_shuffle_stream_count);

~ExchangeReceiverBase();

void setUpConnection();

void cancel();

void close();
Expand Down Expand Up @@ -175,6 +172,7 @@ class ExchangeReceiverBase
void readLoop(const Request & req);
template <bool enable_fine_grained_shuffle>
void reactor(const std::vector<Request> & async_requests);
void setUpConnection();

bool setEndState(ExchangeReceiverState new_state);
String getStatusString();
Expand Down
6 changes: 5 additions & 1 deletion dbms/src/Flash/Mpp/MPPHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include <Flash/Mpp/MPPHandler.h>
#include <Flash/Mpp/Utils.h>

#include <ext/scope_guard.h>

namespace DB
{
namespace FailPoints
Expand Down Expand Up @@ -44,7 +46,9 @@ void MPPHandler::handleError(const MPPTaskPtr & task, String error)
grpc::Status MPPHandler::execute(const ContextPtr & context, mpp::DispatchTaskResponse * response)
{
MPPTaskPtr task = nullptr;
current_memory_tracker = nullptr; /// to avoid reusing threads in gRPC
SCOPE_EXIT({
current_memory_tracker = nullptr; /// to avoid reusing threads in gRPC
});
try
{
Stopwatch stopwatch;
Expand Down
9 changes: 0 additions & 9 deletions dbms/src/Flash/Mpp/MPPReceiverSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,6 @@ void MPPReceiverSet::cancel()
cop_reader->cancel();
}


void MPPReceiverSet::setUpConnection()
{
for (auto & it : exchange_receiver_map)
{
it.second->setUpConnection();
}
}

void MPPReceiverSet::close()
{
for (auto & it : exchange_receiver_map)
Expand Down
1 change: 0 additions & 1 deletion dbms/src/Flash/Mpp/MPPReceiverSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ class MPPReceiverSet
void addExchangeReceiver(const String & executor_id, const ExchangeReceiverPtr & exchange_receiver);
void addCoprocessorReader(const CoprocessorReaderPtr & coprocessor_reader);
ExchangeReceiverPtr getExchangeReceiver(const String & executor_id) const;
void setUpConnection();
void cancel();
void close();

Expand Down
30 changes: 24 additions & 6 deletions dbms/src/Flash/Mpp/MPPTask.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
#include <DataStreams/IProfilingBlockInputStream.h>
#include <DataStreams/SquashingBlockOutputStream.h>
#include <Flash/Coprocessor/DAGCodec.h>
#include <Flash/Coprocessor/DAGQuerySource.h>
#include <Flash/Coprocessor/DAGUtils.h>
#include <Flash/Mpp/ExchangeReceiver.h>
#include <Flash/Mpp/GRPCReceiverContext.h>
#include <Flash/Mpp/MPPTask.h>
#include <Flash/Mpp/MPPTunnelSet.h>
#include <Flash/Mpp/Utils.h>
#include <Flash/Planner/PlanQuerySource.h>
#include <Flash/Statistics/traverseExecutors.h>
#include <Flash/executeQuery.h>
#include <Interpreters/ProcessList.h>
Expand Down Expand Up @@ -59,14 +61,16 @@ MPPTask::MPPTask(const mpp::TaskMeta & meta_, const ContextPtr & context_)
, mpp_task_statistics(id, meta.address())
, needed_threads(0)
, schedule_state(ScheduleState::WAITING)
{}
{
current_memory_tracker = nullptr;
}

MPPTask::~MPPTask()
{
/// MPPTask may be destructed by a different thread, so set current_memory_tracker
/// to the query's memory tracker in the destructor.
if (current_memory_tracker != memory_tracker)
current_memory_tracker = memory_tracker;
if (current_memory_tracker != process_list_entry->get().getMemoryTrackerPtr().get())
current_memory_tracker = process_list_entry->get().getMemoryTrackerPtr().get();
abortTunnels("", true);
if (schedule_state == ScheduleState::SCHEDULED)
{
Expand Down Expand Up @@ -170,8 +174,7 @@ void MPPTask::initExchangeReceivers()
context->getMaxStreams(),
log->identifier(),
executor_id,
executor.fine_grained_shuffle_stream_count(),
true);
executor.fine_grained_shuffle_stream_count());
if (status != RUNNING)
throw Exception("exchange receiver map can not be initialized, because the task is not in running state");

Expand Down Expand Up @@ -281,6 +284,21 @@ void MPPTask::prepare(const mpp::DispatchTaskRequest & task_request)

context->setDAGContext(dag_context.get());

if (context->getSettingsRef().enable_planner)
query_source = std::make_unique<PlanQuerySource>(*context);
else
query_source = std::make_unique<DAGQuerySource>(*context);

auto [query, ast] = query_source->parse(context->getSettingsRef().max_query_size);
process_list_entry = context->getProcessList().insert(
query,
ast.get(),
context->getClientInfo(),
context->getSettingsRef());

context->setProcessListElement(&process_list_entry->get());
dag_context->setProcessListEntry(process_list_entry);

if (dag_context->isRootMPPTask())
{
FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::exception_before_mpp_register_tunnel_for_root_mpp_task);
Expand Down Expand Up @@ -336,6 +354,7 @@ void MPPTask::preprocess()
void MPPTask::runImpl()
{
CPUAffinityManager::getInstance().bindSelfQueryThread();
current_memory_tracker = process_list_entry->get().getMemoryTrackerPtr().get();
if (!switchStatus(INITIALIZING, RUNNING))
{
LOG_WARNING(log, "task not in initializing state, skip running");
Expand All @@ -360,7 +379,6 @@ void MPPTask::runImpl()
scheduleOrWait();

LOG_FMT_INFO(log, "task starts running");
memory_tracker = current_memory_tracker;
if (status.load() != RUNNING)
{
/// when task is in running state, canceling the task will call sendCancelToQuery to do the cancellation, however
Expand Down
5 changes: 4 additions & 1 deletion dbms/src/Flash/Mpp/MPPTask.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <Flash/Mpp/MPPTunnelSet.h>
#include <Flash/Mpp/TaskStatus.h>
#include <Interpreters/Context.h>
#include <Interpreters/IQuerySource.h>
#include <common/logger_useful.h>
#include <common/types.h>
#include <kvproto/mpp.pb.h>
Expand Down Expand Up @@ -123,7 +124,9 @@ class MPPTask : public std::enable_shared_from_this<MPPTask>
// `dag_context` holds inputstreams which could hold ref to `context` so it should be destructed
// before `context`.
std::unique_ptr<DAGContext> dag_context;
MemoryTracker * memory_tracker = nullptr;

std::shared_ptr<ProcessListEntry> process_list_entry;
std::unique_ptr<IQuerySource> query_source;

std::atomic<TaskStatus> status{INITIALIZING};
String err_string;
Expand Down
15 changes: 5 additions & 10 deletions dbms/src/Flash/Mpp/MPPTunnel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ void MPPTunnel::write(const mpp::MPPDataPacket & data)
throw Exception(fmt::format("write to tunnel which is already closed."));
}

if (tunnel_sender->push(std::make_shared<DB::TrackedMppDataPacket>(data, getMemTracker())))
if (tunnel_sender->push(data))
{
connection_profile_info.bytes += data.ByteSizeLong();
connection_profile_info.packets += 1;
Expand Down Expand Up @@ -179,14 +179,14 @@ void MPPTunnel::connect(PacketWriter * writer)
case TunnelSenderMode::LOCAL:
{
RUNTIME_ASSERT(writer == nullptr, log);
local_tunnel_sender = std::make_shared<LocalTunnelSender>(queue_size, log, tunnel_id);
local_tunnel_sender = std::make_shared<LocalTunnelSender>(queue_size, mem_tracker, log, tunnel_id);
tunnel_sender = local_tunnel_sender;
break;
}
case TunnelSenderMode::SYNC_GRPC:
{
RUNTIME_ASSERT(writer != nullptr, log, "Sync writer shouldn't be null");
sync_tunnel_sender = std::make_shared<SyncTunnelSender>(queue_size, log, tunnel_id);
sync_tunnel_sender = std::make_shared<SyncTunnelSender>(queue_size, mem_tracker, log, tunnel_id);
sync_tunnel_sender->startSendThread(writer);
tunnel_sender = sync_tunnel_sender;
break;
Expand Down Expand Up @@ -214,11 +214,11 @@ void MPPTunnel::connectAsync(IAsyncCallData * call_data)
auto kick_func_for_test = call_data->getKickFuncForTest();
if (unlikely(kick_func_for_test.has_value()))
{
async_tunnel_sender = std::make_shared<AsyncTunnelSender>(queue_size, log, tunnel_id, kick_func_for_test.value());
async_tunnel_sender = std::make_shared<AsyncTunnelSender>(queue_size, mem_tracker, log, tunnel_id, kick_func_for_test.value());
}
else
{
async_tunnel_sender = std::make_shared<AsyncTunnelSender>(queue_size, log, tunnel_id, call_data->grpcCall());
async_tunnel_sender = std::make_shared<AsyncTunnelSender>(queue_size, mem_tracker, log, tunnel_id, call_data->grpcCall());
}
call_data->attachAsyncTunnelSender(async_tunnel_sender);
tunnel_sender = async_tunnel_sender;
Expand Down Expand Up @@ -305,11 +305,6 @@ StringRef MPPTunnel::statusToString()
}
}

void MPPTunnel::updateMemTracker()
{
mem_tracker = current_memory_tracker ? current_memory_tracker->shared_from_this() : nullptr;
}

void TunnelSender::consumerFinish(const String & msg)
{
LOG_FMT_TRACE(log, "calling consumer Finish");
Expand Down
Loading

0 comments on commit e403b08

Please sign in to comment.