Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions programs/local/LocalServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ namespace ServerSetting
extern const ServerSettingsUInt64 max_prefixes_deserialization_thread_pool_size;
extern const ServerSettingsUInt64 max_prefixes_deserialization_thread_pool_free_size;
extern const ServerSettingsUInt64 prefixes_deserialization_thread_pool_thread_pool_queue_size;
extern const ServerSettingsUInt64 memory_worker_period_ms;
extern const ServerSettingsBool memory_worker_correct_memory_tracker;
}

namespace ErrorCodes
Expand Down Expand Up @@ -236,6 +238,15 @@ void LocalServer::initialize(Poco::Util::Application & self)
});
#endif

#if defined(OS_LINUX)
memory_worker = std::make_unique<MemoryWorker>(
server_settings[ServerSetting::memory_worker_period_ms],
server_settings[ServerSetting::memory_worker_correct_memory_tracker],
/* use_cgroup */ true,
nullptr);
memory_worker->start();
#endif

getIOThreadPool().initialize(
server_settings[ServerSetting::max_io_thread_pool_size],
server_settings[ServerSetting::max_io_thread_pool_free_size],
Expand Down
3 changes: 3 additions & 0 deletions programs/local/LocalServer.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <Loggers/Loggers.h>
#include <Common/InterruptListener.h>
#include <Common/StatusFile.h>
#include <Common/MemoryWorker.h>

#include <filesystem>
#include <memory>
Expand Down Expand Up @@ -91,6 +92,8 @@ class LocalServer : public ClientApplicationBase, public Loggers

private:
void cleanStreamingQuery();

std::unique_ptr<MemoryWorker> memory_worker;
};

}
70 changes: 64 additions & 6 deletions programs/local/chdb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,41 @@
#include "PythonTableCache.h"
#endif

extern thread_local bool chdb_destructor_cleanup_in_progress;

namespace CHDB
{

/**
 * RAII guard for accurate memory tracking in chDB external interfaces
 *
 * When Python (or other programming language) threads call chDB-provided interfaces
 * such as chdb_destroy_query_result, the memory released cannot be accurately tracked
 * by ClickHouse's MemoryTracker, which may lead to false reports of insufficient memory.
 *
 * Therefore, for all externally exposed chDB interfaces, ChdbDestructorGuard must be
 * used at the beginning of execution to provide thread marking, enabling MemoryTracker
 * to accurately track memory changes.
 *
 * The guard sets the thread-local flag chdb_destructor_cleanup_in_progress for its
 * lifetime. On destruction it restores the value the flag had when the guard was
 * constructed instead of unconditionally clearing it, so that guards may nest on the
 * same thread (for example, one guarded entry point calling another) without the
 * inner guard switching the marking off while the outer scope is still running.
 *
 * The guard is neither copyable nor movable: it is bound to the scope that created it.
 */
class ChdbDestructorGuard
{
public:
    /// Remembers the current flag value, then marks the thread.
    ChdbDestructorGuard() noexcept
        : previous_state(chdb_destructor_cleanup_in_progress)
    {
        chdb_destructor_cleanup_in_progress = true;
    }

    /// Puts the flag back to what it was before this guard was created.
    ~ChdbDestructorGuard()
    {
        chdb_destructor_cleanup_in_progress = previous_state;
    }

    ChdbDestructorGuard(const ChdbDestructorGuard &) = delete;
    ChdbDestructorGuard & operator=(const ChdbDestructorGuard &) = delete;
    ChdbDestructorGuard(ChdbDestructorGuard &&) = delete;
    ChdbDestructorGuard & operator=(ChdbDestructorGuard &&) = delete;

private:
    /// Flag value observed at construction; false for an outermost guard on a fresh thread.
    const bool previous_state;
};

static std::shared_mutex global_connection_mutex;
static std::mutex CHDB_MUTEX;
chdb_conn * global_conn_ptr = nullptr;
Expand Down Expand Up @@ -222,9 +254,6 @@ static std::pair<QueryResultPtr, bool> createQueryResult(DB::LocalServer * serve
else if (!req.isIteration())
{
server->streaming_query_context = std::make_shared<DB::StreamingQueryContext>();
/// TODO: support memory tracker for streaming query
server->streaming_query_context->limit = total_memory_tracker.getHardLimit();
total_memory_tracker.setHardLimit(0);
query_result = createStreamingQueryResult(server, req);
is_end = !query_result->getError().empty();

Expand All @@ -246,9 +275,6 @@ static std::pair<QueryResultPtr, bool> createQueryResult(DB::LocalServer * serve
{
if (server->streaming_query_context)
{
total_memory_tracker.resetCounters();
MemoryTracker::updateRSS(0);
total_memory_tracker.setHardLimit(server->streaming_query_context->limit);
server->streaming_query_context.reset();
}
#if USE_PYTHON
Expand Down Expand Up @@ -429,6 +455,8 @@ using namespace CHDB;

local_result * query_stable(int argc, char ** argv)
{
ChdbDestructorGuard guard;

auto query_result = pyEntryClickHouseLocal(argc, argv);
if (!query_result->getError().empty() || query_result->result_buffer == nullptr)
return nullptr;
Expand All @@ -445,6 +473,8 @@ local_result * query_stable(int argc, char ** argv)

void free_result(local_result * result)
{
ChdbDestructorGuard guard;

if (!result)
{
return;
Expand All @@ -460,6 +490,8 @@ void free_result(local_result * result)

local_result_v2 * query_stable_v2(int argc, char ** argv)
{
ChdbDestructorGuard guard;

// pyEntryClickHouseLocal may throw some serious exceptions, although it's not likely
// to happen in the context of clickhouse-local. we catch them here and return an error
local_result_v2 * res = nullptr;
Expand Down Expand Up @@ -489,6 +521,8 @@ local_result_v2 * query_stable_v2(int argc, char ** argv)

void free_result_v2(local_result_v2 * result)
{
ChdbDestructorGuard guard;

if (!result)
return;

Expand Down Expand Up @@ -693,6 +727,8 @@ void close_conn(chdb_conn ** conn)

struct local_result_v2 * query_conn(chdb_conn * conn, const char * query, const char * format)
{
ChdbDestructorGuard guard;

// Add connection validity check under global lock
std::shared_lock<std::shared_mutex> global_lock(global_connection_mutex);

Expand All @@ -707,6 +743,8 @@ struct local_result_v2 * query_conn(chdb_conn * conn, const char * query, const

chdb_streaming_result * query_conn_streaming(chdb_conn * conn, const char * query, const char * format)
{
ChdbDestructorGuard guard;

// Add connection validity check under global lock
std::shared_lock<std::shared_mutex> global_lock(global_connection_mutex);

Expand Down Expand Up @@ -744,6 +782,8 @@ const char * chdb_streaming_result_error(chdb_streaming_result * result)

local_result_v2 * chdb_streaming_fetch_result(chdb_conn * conn, chdb_streaming_result * result)
{
ChdbDestructorGuard guard;

// Add connection validity check under global lock
std::shared_lock<std::shared_mutex> global_lock(global_connection_mutex);

Expand All @@ -758,6 +798,8 @@ local_result_v2 * chdb_streaming_fetch_result(chdb_conn * conn, chdb_streaming_r

void chdb_streaming_cancel_query(chdb_conn * conn, chdb_streaming_result * result)
{
ChdbDestructorGuard guard;

// Add connection validity check under global lock
std::shared_lock<std::shared_mutex> global_lock(global_connection_mutex);

Expand All @@ -766,15 +808,19 @@ void chdb_streaming_cancel_query(chdb_conn * conn, chdb_streaming_result * resul

auto * queue = static_cast<CHDB::QueryQueue *>(conn->queue);
auto query_result = executeQueryRequest(queue, nullptr, nullptr, CHDB::QueryType::TYPE_STREAMING_ITER, result, true);

query_result.reset();
}

void chdb_destroy_result(chdb_streaming_result * result)
{
ChdbDestructorGuard guard;

if (!result)
return;

auto stream_query_result = reinterpret_cast<StreamQueryResult *>(result);

delete stream_query_result;
}

Expand All @@ -799,6 +845,8 @@ void chdb_close_conn(chdb_connection * conn)

chdb_result * chdb_query(chdb_connection conn, const char * query, const char * format)
{
ChdbDestructorGuard guard;

std::shared_lock<std::shared_mutex> global_lock(global_connection_mutex);

if (!conn)
Expand All @@ -823,6 +871,8 @@ chdb_result * chdb_query(chdb_connection conn, const char * query, const char *

chdb_result * chdb_query_cmdline(int argc, char ** argv)
{
ChdbDestructorGuard guard;

MaterializedQueryResult * result = nullptr;
try
{
Expand All @@ -844,6 +894,8 @@ chdb_result * chdb_query_cmdline(int argc, char ** argv)

chdb_result * chdb_stream_query(chdb_connection conn, const char * query, const char * format)
{
ChdbDestructorGuard guard;

std::shared_lock<std::shared_mutex> global_lock(global_connection_mutex);

if (!conn)
Expand Down Expand Up @@ -873,6 +925,8 @@ chdb_result * chdb_stream_query(chdb_connection conn, const char * query, const

chdb_result * chdb_stream_fetch_result(chdb_connection conn, chdb_result * result)
{
ChdbDestructorGuard guard;

std::shared_lock<std::shared_mutex> global_lock(global_connection_mutex);

if (!conn)
Expand Down Expand Up @@ -903,6 +957,8 @@ chdb_result * chdb_stream_fetch_result(chdb_connection conn, chdb_result * resul

void chdb_stream_cancel_query(chdb_connection conn, chdb_result * result)
{
ChdbDestructorGuard guard;

std::shared_lock<std::shared_mutex> global_lock(global_connection_mutex);

if (!result || !conn)
Expand All @@ -919,6 +975,8 @@ void chdb_stream_cancel_query(chdb_connection conn, chdb_result * result)

void chdb_destroy_query_result(chdb_result * result)
{
ChdbDestructorGuard guard;

if (!result)
return;

Expand Down
1 change: 0 additions & 1 deletion src/Client/ClientBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ struct StreamingQueryContext
ASTPtr parsed_query;
void * streaming_result = nullptr;
bool is_streaming_query = true;
Int64 limit = 0;

StreamingQueryContext() = default;
};
Expand Down
28 changes: 19 additions & 9 deletions src/Common/CurrentMemoryTracker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
thread_local bool memory_tracker_always_throw_logical_error_on_allocation = false;
#endif

thread_local bool chdb_destructor_cleanup_in_progress = false;

namespace DB
{
namespace ErrorCodes
Expand All @@ -31,6 +33,9 @@ MemoryTracker * getMemoryTracker()
if (DB::MainThreadStatus::get())
return &total_memory_tracker;

if (chdb_destructor_cleanup_in_progress)
return &total_memory_tracker;

return nullptr;
}

Expand Down Expand Up @@ -70,18 +75,23 @@ AllocationTrace CurrentMemoryTracker::allocImpl(Int64 size, bool throw_if_memory
}
else
{
Int64 will_be = current_thread->untracked_memory + size;

if (will_be > current_thread->untracked_memory_limit)
Int64 previous_untracked_memory = current_thread->untracked_memory;
current_thread->untracked_memory += size;
if (current_thread->untracked_memory > current_thread->untracked_memory_limit)
{
auto res = memory_tracker->allocImpl(will_be, throw_if_memory_exceeded);
Int64 current_untracked_memory = current_thread->untracked_memory;
current_thread->untracked_memory = 0;
return res;
}

/// Update after successful allocations,
/// since failed allocations should not be taken into account.
current_thread->untracked_memory = will_be;
try
{
return memory_tracker->allocImpl(current_untracked_memory, throw_if_memory_exceeded);
}
catch (...)
{
current_thread->untracked_memory += previous_untracked_memory;
throw;
}
}
}

return AllocationTrace(memory_tracker->getSampleProbability(size));
Expand Down
4 changes: 2 additions & 2 deletions src/Common/MemoryWorker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ MemoryWorker::MemoryWorker(uint64_t period_ms_, bool correct_tracker_, bool use_
static constexpr uint64_t cgroups_memory_usage_tick_ms{50};

const auto [cgroup_path, version] = getCgroupsPath();
LOG_INFO(
LOG_TRACE(
getLogger("CgroupsReader"),
"Will create cgroup reader from '{}' (cgroups version: {})",
cgroup_path,
Expand Down Expand Up @@ -275,7 +275,7 @@ void MemoryWorker::start()
if (source == MemoryUsageSource::None)
return;

LOG_INFO(
LOG_TRACE(
getLogger("MemoryWorker"),
"Starting background memory thread with period of {}ms, using {} as source",
period_ms,
Expand Down
3 changes: 2 additions & 1 deletion src/Common/ThreadStatus.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,9 @@ void ThreadStatus::flushUntrackedMemory()
if (untracked_memory == 0)
return;

memory_tracker.adjustWithUntrackedMemory(untracked_memory);
Int64 current_untracked_memory = current_thread->untracked_memory;
untracked_memory = 0;
memory_tracker.adjustWithUntrackedMemory(current_untracked_memory);
}

bool ThreadStatus::isQueryCanceled() const
Expand Down
4 changes: 4 additions & 0 deletions src/Core/BackgroundSchedulePool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,8 @@ void BackgroundSchedulePool::threadFunction()
{
TaskInfoPtr task;

current_thread->flushUntrackedMemory();

{
UniqueLock tasks_lock(tasks_mutex);

Expand All @@ -306,6 +308,8 @@ void BackgroundSchedulePool::threadFunction()

if (task)
task->execute();

current_thread->flushUntrackedMemory();
}
}

Expand Down
4 changes: 4 additions & 0 deletions src/Storages/MergeTree/MergeTreeBackgroundExecutor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,8 @@ void MergeTreeBackgroundExecutor<Queue>::threadFunction()
{
setThreadName(name.c_str());

current_thread->flushUntrackedMemory();

DENY_ALLOCATIONS_IN_SCOPE;

while (true)
Expand Down Expand Up @@ -373,6 +375,8 @@ void MergeTreeBackgroundExecutor<Queue>::threadFunction()
tryLogCurrentException(__PRETTY_FUNCTION__);
});
}

current_thread->flushUntrackedMemory();
}
}

Expand Down
Loading