Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race condition associated to module reload and modulemanager #434

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/agent/include/agent.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@

#include <sysInfo.hpp>

#include <atomic>
#include <memory>
#include <mutex>
#include <string>

/// @brief Agent class
Expand Down Expand Up @@ -76,4 +78,13 @@ class Agent

/// @brief Centralized configuration
centralized_configuration::CentralizedConfiguration m_centralizedConfiguration;

/// @brief Mutex to coordinate agent reload
std::mutex m_reloadMutex;

/// @brief Indicates if the agent is running
std::atomic<bool> m_running = true;

/// @brief Agent thread count
size_t m_agentThreadCount;
};
47 changes: 37 additions & 10 deletions src/agent/src/agent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,14 @@ Agent::Agent(const std::string& configFilePath, std::unique_ptr<ISignalHandler>

m_centralizedConfiguration.ReloadModulesFunction([this]() { ReloadModules(); });

auto agentThreadCount =
m_agentThreadCount =
m_configurationParser->GetConfig<size_t>("agent", "thread_count").value_or(config::DEFAULT_THREAD_COUNT);

if (agentThreadCount < config::DEFAULT_THREAD_COUNT)
if (m_agentThreadCount < config::DEFAULT_THREAD_COUNT)
{
LogWarn("thread_count must be greater than {}. Using default value.", config::DEFAULT_THREAD_COUNT);
agentThreadCount = config::DEFAULT_THREAD_COUNT;
m_agentThreadCount = config::DEFAULT_THREAD_COUNT;
}

m_taskManager.Start(agentThreadCount);
}

Agent::~Agent()
Expand All @@ -77,15 +75,34 @@ Agent::~Agent()

void Agent::ReloadModules()
{
LogInfo("Reloading Modules");
m_configurationParser->ReloadConfiguration();
m_moduleManager.Stop();
m_moduleManager.Setup();
m_moduleManager.Start();
std::lock_guard<std::mutex> lock(m_reloadMutex);

if (m_running.load())
{
try
{
LogInfo("Reloading Modules");
m_configurationParser->ReloadConfiguration();
m_moduleManager.Stop();
m_moduleManager.Setup();
m_moduleManager.Start();
LogInfo("Modules reloaded");
}
catch (const std::exception& e)
{
LogError("Error reloading modules: {}", e.what());
}
}
else
{
LogWarn("Agent cannot reload modules while start up or shutdown is in progress.");
}
}

void Agent::Run()
{
m_taskManager.Start(m_agentThreadCount);

// Check if the server recognizes the agent
m_communicator.SendAuthenticationRequest();

Expand Down Expand Up @@ -144,8 +161,18 @@ void Agent::Run()
}),
"CommandsProcessing");

{
std::unique_lock<std::mutex> lock(m_reloadMutex);
m_running.store(true);
}

m_signalHandler->WaitForSignal();

{
std::unique_lock<std::mutex> lock(m_reloadMutex);
m_running.store(false);
}

m_commandHandler.Stop();
m_communicator.Stop();
m_moduleManager.Stop();
Expand Down
1 change: 1 addition & 0 deletions src/agent/src/signal_handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ void SignalHandler::HandleSignal([[maybe_unused]] int signal)

void SignalHandler::WaitForSignal()
{
KeepRunning.store(true);
std::unique_lock<std::mutex> lock(m_cvMutex);
m_cv.wait(lock, [] { return !KeepRunning.load(); });
}
4 changes: 4 additions & 0 deletions src/agent/task_manager/include/task_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <vector>
Expand Down Expand Up @@ -53,4 +54,7 @@ class TaskManager : public ITaskManager<boost::asio::awaitable<void>>

/// @brief Number of enqueued threads
size_t m_numEnqueuedThreads = 0;

/// @brief Mutex to control Start and Stop operations
mutable std::mutex m_mutex;
};
11 changes: 10 additions & 1 deletion src/agent/task_manager/src/task_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ TaskManager::~TaskManager()

void TaskManager::Start(size_t numThreads)
{
std::lock_guard<std::mutex> lock(m_mutex);

if (m_work || !m_threads.empty())
{
LogError("Task manager already started");
Expand All @@ -32,9 +34,11 @@ void TaskManager::Start(size_t numThreads)

void TaskManager::Stop()
{
std::lock_guard<std::mutex> lock(m_mutex);

if (m_work)
{
m_work->reset();
m_work.reset();
}

if (!m_ioContext.stopped())
Expand All @@ -59,6 +63,8 @@ void TaskManager::Stop()

void TaskManager::EnqueueTask(std::function<void()> task, const std::string& taskID)
{
std::lock_guard<std::mutex> lock(m_mutex);

if (++m_numEnqueuedThreads > m_threads.size())
{
LogError("Enqueued more threaded tasks than available threads");
Expand All @@ -82,6 +88,8 @@ void TaskManager::EnqueueTask(std::function<void()> task, const std::string& tas

void TaskManager::EnqueueTask(boost::asio::awaitable<void> task, const std::string& taskID)
{
std::lock_guard<std::mutex> lock(m_mutex);

// NOLINTBEGIN(cppcoreguidelines-avoid-capturing-lambda-coroutines)
boost::asio::co_spawn(
m_ioContext,
Expand All @@ -105,5 +113,6 @@ void TaskManager::EnqueueTask(boost::asio::awaitable<void> task, const std::stri

size_t TaskManager::GetNumEnqueuedThreads() const
{
std::lock_guard<std::mutex> lock(m_mutex);
return m_numEnqueuedThreads;
}
22 changes: 18 additions & 4 deletions src/modules/include/moduleManager.hpp
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
#pragma once

#include <map>
#include <memory>
#include <string>
#include <thread>
#include <multitype_queue.hpp>
#include <moduleWrapper.hpp>
#include <task_manager.hpp>

#include <boost/asio/awaitable.hpp>

#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <thread>


class ModuleManager {
public:
ModuleManager(const std::function<int(Message)>& pushMessage, std::shared_ptr<configuration::ConfigurationParser> configurationParser, std::string uuid)
Expand Down Expand Up @@ -46,6 +49,15 @@ class ModuleManager {

void AddModules();
std::shared_ptr<ModuleWrapper> GetModule(const std::string & name);

/// @brief Start the modules
///
/// This function begins the procedure to start the modules and blocks until the Start function
/// for each module has been called. However, it does not guarantee that the modules are fully
/// operational upon return; they may still be in the process of initializing.
///
/// @note Call this function before interacting with the modules to ensure the startup process is initiated.
/// @warning Ensure the modules have fully started before performing any operations that depend on them.
void Start();
void Setup();
void Stop();
Expand All @@ -56,4 +68,6 @@ class ModuleManager {
std::function<int(Message)> m_pushMessage;
std::shared_ptr<configuration::ConfigurationParser> m_configurationParser;
std::string m_agentUUID;
std::mutex m_mutex;
std::atomic<int> m_started {0};
};
49 changes: 40 additions & 9 deletions src/modules/src/moduleManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,31 +24,62 @@ void ModuleManager::AddModules() {
Setup();
}

std::shared_ptr<ModuleWrapper> ModuleManager::GetModule(const std::string & name) {
auto it = m_modules.find(name);
if (it != m_modules.end()) {
std::shared_ptr<ModuleWrapper> ModuleManager::GetModule(const std::string & name)
{
if (auto it = m_modules.find(name); it != m_modules.end())
{
return it->second;
}
return nullptr;
}

void ModuleManager::Start() {
void ModuleManager::Start()
{
std::unique_lock<std::mutex> lock(m_mutex);

m_taskManager.Start(m_modules.size());
std::condition_variable cv;

for (const auto &[_, module] : m_modules)
{
m_taskManager.EnqueueTask([module]() { module->Start(); }, module->Name());
m_taskManager.EnqueueTask(
[module, this, &cv]
{
++m_started;
cv.notify_one();
module->Start();
}
, module->Name()
);
}

cv.wait(
lock,
[this]
{
return m_started.load() == static_cast<int>(m_modules.size());
}
);
}

void ModuleManager::Setup() {
for (const auto &[_, module] : m_modules) {
void ModuleManager::Setup()
{
std::lock_guard<std::mutex> lock(m_mutex);

for (const auto &[_, module] : m_modules)
{
module->Setup(m_configurationParser);
}
}

void ModuleManager::Stop() {
for (const auto &[_, module] : m_modules) {
void ModuleManager::Stop()
{
std::lock_guard<std::mutex> lock(m_mutex);

for (const auto &[_, module] : m_modules)
{
module->Stop();
m_started--;
}
m_taskManager.Stop();
}
Loading