Skip to content

Commit

Permalink
feat: check if agent is registered or key is valid before launching c…
Browse files Browse the repository at this point in the history
…oroutines
  • Loading branch information
TomasTurina committed Dec 2, 2024
1 parent 5d629fe commit bef68dc
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 11 deletions.
8 changes: 4 additions & 4 deletions src/agent/communicator/include/communicator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ namespace communicator
.value_or(config::agent::DEFAULT_RETRY_INTERVAL);
}

/// @brief Sends an authentication request to the manager
/// @return The HTTP status of the authentication request
boost::beast::http::status SendAuthenticationRequest();

/// @brief Waits for the authentication token to expire and authenticates again
boost::asio::awaitable<void> WaitForTokenExpirationAndAuthenticate();

Expand Down Expand Up @@ -94,10 +98,6 @@ namespace communicator
/// @return The remaining time in seconds until the authentication token expires
long GetTokenRemainingSecs() const;

/// @brief Sends an authentication request to the manager
/// @return The HTTP status of the authentication request
boost::beast::http::status SendAuthenticationRequest();

/// @brief Checks if the authentication token has expired and authenticates again if necessary
void TryReAuthenticate();

Expand Down
26 changes: 20 additions & 6 deletions src/agent/communicator/src/communicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,19 +89,33 @@ namespace communicator
const auto executor = co_await boost::asio::this_coro::executor;
m_tokenExpTimer = std::make_unique<boost::asio::steady_timer>(executor);

if (auto remainingSecs = GetTokenRemainingSecs(); remainingSecs > TOKEN_PRE_EXPIRY_SECS)
{
m_tokenExpTimer->expires_after(
std::chrono::milliseconds((remainingSecs - TOKEN_PRE_EXPIRY_SECS) * A_SECOND_IN_MILLIS));
co_await m_tokenExpTimer->async_wait(boost::asio::use_awaitable);
}

while (m_keepRunning.load())
{
const auto duration = [this]()
{
const auto result = SendAuthenticationRequest();
if (result != boost::beast::http::status::ok)
try
{
return std::chrono::milliseconds(m_retryInterval);
const auto result = SendAuthenticationRequest();
if (result != boost::beast::http::status::ok)
{
return std::chrono::milliseconds(m_retryInterval);
}
else
{
return std::chrono::milliseconds((GetTokenRemainingSecs() - TOKEN_PRE_EXPIRY_SECS) *
A_SECOND_IN_MILLIS);
}
}
else
catch (const std::exception&)
{
return std::chrono::milliseconds((GetTokenRemainingSecs() - TOKEN_PRE_EXPIRY_SECS) *
A_SECOND_IN_MILLIS);
return std::chrono::milliseconds(m_retryInterval);
}
}();

Expand Down
22 changes: 22 additions & 0 deletions src/agent/communicator/src/http_client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <nlohmann/json.hpp>

#include <chrono>
#include <iostream>
#include <sstream>
#include <string>

Expand Down Expand Up @@ -241,6 +242,27 @@ namespace http_client

if (res.result() != boost::beast::http::status::ok)
{
if (res.result() == boost::beast::http::status::unauthorized ||
res.result() == boost::beast::http::status::forbidden)
{
std::string message {};

try
{
message = nlohmann::json::parse(boost::beast::buffers_to_string(res.body().data()))
.at("message")
.get_ref<const std::string&>();
}
catch (const std::exception& e)
{
LogError("Error parsing message in response: {}.", e.what());
}

if (message == "Invalid key" || message == "Agent does not exist")
{
throw std::runtime_error(message);
}
}
LogWarn("Error: {}.", res.result_int());
return std::nullopt;
}
Expand Down
27 changes: 26 additions & 1 deletion src/agent/communicator/tests/http_client_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -478,14 +478,39 @@ TEST_F(HttpClientTest, AuthenticateWithUuidAndKey_Failure)
EXPECT_CALL(*mockSocket, Connect(_, _)).Times(1);
EXPECT_CALL(*mockSocket, Write(_, _)).Times(1);
EXPECT_CALL(*mockSocket, Read(_, _))
.WillOnce([](auto& res, auto&) { res.result(boost::beast::http::status::unauthorized); });
.WillOnce(
[](auto& res, auto&)
{
res.result(boost::beast::http::status::unauthorized);
boost::beast::ostream(res.body()) << R"({"message":"Try again"})";
});

const auto token =
client->AuthenticateWithUuidAndKey("https://localhost:8080", "Wazuh 5.0.0", "test-uuid", "test-key");

EXPECT_FALSE(token.has_value());
}

TEST_F(HttpClientTest, AuthenticateWithUuidAndKey_FailureThrowsException)
{
SetupMockResolverFactory();
SetupMockSocketFactory();

EXPECT_CALL(*mockResolver, Resolve(_, _)).WillOnce(Return(dummyResults));
EXPECT_CALL(*mockSocket, Connect(_, _)).Times(1);
EXPECT_CALL(*mockSocket, Write(_, _)).Times(1);
EXPECT_CALL(*mockSocket, Read(_, _))
.WillOnce(
[](auto& res, auto&)
{
res.result(boost::beast::http::status::unauthorized);
boost::beast::ostream(res.body()) << R"({"message":"Invalid key"})";
});

EXPECT_THROW(client->AuthenticateWithUuidAndKey("https://localhost:8080", "Wazuh 5.0.0", "test-uuid", "test-key"),
std::runtime_error);
}

TEST_F(HttpClientTest, AuthenticateWithUserPassword_Success)
{
SetupMockResolverFactory();
Expand Down
3 changes: 3 additions & 0 deletions src/agent/src/agent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ Agent::~Agent()

void Agent::Run()
{
// Check if agent is registered
m_communicator.SendAuthenticationRequest();

m_taskManager.EnqueueTask(m_communicator.WaitForTokenExpirationAndAuthenticate());

m_taskManager.EnqueueTask(m_communicator.GetCommandsFromManager(
Expand Down

0 comments on commit bef68dc

Please sign in to comment.