diff --git a/libamqpprox/amqpprox_connectionlimitermanager.cpp b/libamqpprox/amqpprox_connectionlimitermanager.cpp index e689531..717dad0 100644 --- a/libamqpprox/amqpprox_connectionlimitermanager.cpp +++ b/libamqpprox/amqpprox_connectionlimitermanager.cpp @@ -21,6 +21,7 @@ #include #include +#include #include namespace Bloomberg { @@ -29,7 +30,7 @@ namespace amqpprox { namespace { void maybePopulateDefaultLimiters( const std::string &vhostName, - uint32_t defaultLimit, + std::optional defaultLimit, ConnectionLimiterManager::ConnectionLimiters &limitersPerVhost) { if (limitersPerVhost.find(vhostName) == limitersPerVhost.end()) { @@ -37,7 +38,7 @@ void maybePopulateDefaultLimiters( limitersPerVhost[vhostName] = { false, std::make_shared( - defaultLimit)}; + *defaultLimit)}; } } } @@ -46,8 +47,8 @@ void maybePopulateDefaultLimiters( ConnectionLimiterManager::ConnectionLimiterManager() : d_connectionRateLimitersPerVhost() , d_alarmOnlyConnectionRateLimitersPerVhost() -, d_defaultConnectionRateLimit(0) -, d_defaultAlarmOnlyConnectionRateLimit(0) +, d_defaultConnectionRateLimit() +, d_defaultAlarmOnlyConnectionRateLimit() , d_mutex() { } @@ -95,7 +96,7 @@ void ConnectionLimiterManager::setDefaultConnectionRateLimit( if (!limiter.second.first) { limiter.second.second = std::make_shared( - d_defaultConnectionRateLimit); + *d_defaultConnectionRateLimit); } } } @@ -113,7 +114,7 @@ void ConnectionLimiterManager::setAlarmOnlyDefaultConnectionRateLimit( if (!limiter.second.first) { limiter.second.second = std::make_shared( - d_defaultAlarmOnlyConnectionRateLimit); + *d_defaultAlarmOnlyConnectionRateLimit); } } } @@ -127,7 +128,7 @@ void ConnectionLimiterManager::removeConnectionRateLimiter( d_connectionRateLimitersPerVhost[vhostName] = { false, std::make_shared( - d_defaultConnectionRateLimit)}; + *d_defaultConnectionRateLimit)}; } else { d_connectionRateLimitersPerVhost.erase(vhostName); @@ -143,7 +144,7 @@ void ConnectionLimiterManager::removeAlarmOnlyConnectionRateLimiter( d_alarmOnlyConnectionRateLimitersPerVhost[vhostName] = { false, std::make_shared( - d_defaultAlarmOnlyConnectionRateLimit)}; + *d_defaultAlarmOnlyConnectionRateLimit)}; } else { d_alarmOnlyConnectionRateLimitersPerVhost.erase(vhostName); @@ -154,7 +155,7 @@ void ConnectionLimiterManager::removeDefaultConnectionRateLimit() { std::lock_guard lg(d_mutex); - d_defaultConnectionRateLimit = 0; + d_defaultConnectionRateLimit.reset(); for (auto it = d_connectionRateLimitersPerVhost.cbegin(); it != d_connectionRateLimitersPerVhost.cend();) { if (!it->second.first) { @@ -170,7 +171,7 @@ void ConnectionLimiterManager::removeAlarmOnlyDefaultConnectionRateLimit() { std::lock_guard lg(d_mutex); - d_defaultAlarmOnlyConnectionRateLimit = 0; + d_defaultAlarmOnlyConnectionRateLimit.reset(); for (auto it = d_alarmOnlyConnectionRateLimitersPerVhost.cbegin(); it != d_alarmOnlyConnectionRateLimitersPerVhost.cend();) { if (!it->second.first) { @@ -263,12 +264,13 @@ ConnectionLimiterManager::getAlarmOnlyConnectionRateLimiter( return nullptr; } -uint32_t ConnectionLimiterManager::getDefaultConnectionRateLimit() const +std::optional +ConnectionLimiterManager::getDefaultConnectionRateLimit() const { return d_defaultConnectionRateLimit; } -uint32_t +std::optional ConnectionLimiterManager::getAlarmOnlyDefaultConnectionRateLimit() const { return d_defaultAlarmOnlyConnectionRateLimit; diff --git a/libamqpprox/amqpprox_connectionlimitermanager.h b/libamqpprox/amqpprox_connectionlimitermanager.h index 8d7a39d..003a3cf 100644 --- a/libamqpprox/amqpprox_connectionlimitermanager.h +++ b/libamqpprox/amqpprox_connectionlimitermanager.h @@ -18,9 +18,9 @@ #include -#include #include #include +#include #include #include #include @@ -58,9 +58,9 @@ class ConnectionLimiterManager { ConnectionLimiters d_connectionRateLimitersPerVhost; ConnectionLimiters d_alarmOnlyConnectionRateLimitersPerVhost; - uint32_t d_defaultConnectionRateLimit; - uint32_t d_defaultAlarmOnlyConnectionRateLimit; - mutable std::mutex d_mutex; + std::optional d_defaultConnectionRateLimit; + std::optional d_defaultAlarmOnlyConnectionRateLimit; + mutable std::mutex d_mutex; public: // CREATORS @@ -142,6 +142,11 @@ class ConnectionLimiterManager { */ bool allowNewConnectionForVhost(const std::string &vhostName); + /** + * \brief Called when a session is marked as disconnected. + */ + void sessionClosedForVhost(const std::string &vhostName); + // ACCESSORS /** * \brief Get particular connection rate limiter based on specified vhost @@ -162,13 +167,13 @@ class ConnectionLimiterManager { * \brief Get default connection rate limit (allowed connections per * second) for all the connecting vhosts */ - uint32_t getDefaultConnectionRateLimit() const; + std::optional getDefaultConnectionRateLimit() const; /** - * \brief Get alarm onlt default connection rate limit (allowed connections + * \brief Get alarm only default connection rate limit (allowed connections * per second) for all the connecting vhosts */ - uint32_t getAlarmOnlyDefaultConnectionRateLimit() const; + std::optional getAlarmOnlyDefaultConnectionRateLimit() const; }; } diff --git a/libamqpprox/amqpprox_fixedwindowconnectionratelimiter.cpp b/libamqpprox/amqpprox_fixedwindowconnectionratelimiter.cpp index 240f4e2..e6eb979 100644 --- a/libamqpprox/amqpprox_fixedwindowconnectionratelimiter.cpp +++ b/libamqpprox/amqpprox_fixedwindowconnectionratelimiter.cpp @@ -20,7 +20,6 @@ #include #include -#include #include namespace Bloomberg { diff --git a/libamqpprox/amqpprox_fixedwindowconnectionratelimiter.h b/libamqpprox/amqpprox_fixedwindowconnectionratelimiter.h index 7744819..a524425 100644 --- a/libamqpprox/amqpprox_fixedwindowconnectionratelimiter.h +++ b/libamqpprox/amqpprox_fixedwindowconnectionratelimiter.h @@ -20,7 +20,6 @@ #include #include -#include #include namespace Bloomberg { @@ -47,8 +46,8 @@ struct LimiterClock { * provided connection limit and time window. The connection rate limit will be * connection limit/timeWindow (average allowed connections in the specified * time window). allowNewConnection member function will return true or false - * based on the rate limit calculation. Implements the LimiterInterface - * interface + * based on the rate limit calculation. Implements the + * ConnectionLimiterInterface interface */ class FixedWindowConnectionRateLimiter : public ConnectionLimiterInterface { protected: @@ -97,7 +96,7 @@ class FixedWindowConnectionRateLimiter : public ConnectionLimiterInterface { // ACCESSORS /** - * \return Information about limiter as a string + * \return Information about connection limiter as a string */ virtual std::string toString() const override; diff --git a/libamqpprox/amqpprox_limitcontrolcommand.cpp b/libamqpprox/amqpprox_limitcontrolcommand.cpp index 5d30b1c..7cbba98 100644 --- a/libamqpprox/amqpprox_limitcontrolcommand.cpp +++ b/libamqpprox/amqpprox_limitcontrolcommand.cpp @@ -67,6 +67,7 @@ void handleConnectionLimitAlarm( output << "Default connection rate limit is set to " << connectionLimiterManager ->getAlarmOnlyDefaultConnectionRateLimit() + .value() << " connections per second in alarm only mode.\n"; output << "The limiter will only log at warning level with " "AMQPPROX_CONNECTION_LIMIT as a substring and the " @@ -122,6 +123,7 @@ void handleConnectionLimit( numberOfConnections); output << "Default connection rate limit is set to " << connectionLimiterManager->getDefaultConnectionRateLimit() + .value() << " connections per second.\n"; } else { @@ -155,19 +157,19 @@ void printVhostLimits( } if (!alarmLimiter && !limiter) { - uint32_t alarmOnlyConnectionRateLimit = + std::optional alarmOnlyConnectionRateLimit = connectionLimiterManager->getAlarmOnlyDefaultConnectionRateLimit(); - uint32_t connectionRateLimit = + std::optional connectionRateLimit = connectionLimiterManager->getDefaultConnectionRateLimit(); if (alarmOnlyConnectionRateLimit || connectionRateLimit) { if (alarmOnlyConnectionRateLimit) { output << "Alarm only limit, for vhost " << vhostName - << ", allow average " << alarmOnlyConnectionRateLimit + << ", allow average " << *alarmOnlyConnectionRateLimit << " number of connections per second.\n"; } if (connectionRateLimit) { output << "For vhost " << vhostName << ", allow average " - << connectionRateLimit + << *connectionRateLimit << " number of connections per second.\n"; } } @@ -182,19 +184,19 @@ void printAllLimits( ConnectionLimiterManager *connectionLimiterManager, ControlCommandOutput &output) { - uint32_t alarmOnlyConnectionRateLimit = + std::optional alarmOnlyConnectionRateLimit = connectionLimiterManager->getAlarmOnlyDefaultConnectionRateLimit(); - uint32_t connectionRateLimit = + std::optional connectionRateLimit = connectionLimiterManager->getDefaultConnectionRateLimit(); if (alarmOnlyConnectionRateLimit || connectionRateLimit) { if (alarmOnlyConnectionRateLimit) { output << "Default limit for any vhost, allow average " - << alarmOnlyConnectionRateLimit + << *alarmOnlyConnectionRateLimit << " connections per second in alarm only mode.\n"; } if (connectionRateLimit) { output << "Default limit for any vhost, allow average " - << connectionRateLimit << " connections per second.\n"; + << *connectionRateLimit << " connections per second.\n"; } } else { diff --git a/tests/amqpprox_connectionlimitermanager.t.cpp b/tests/amqpprox_connectionlimitermanager.t.cpp index faf3d14..78e140f 100644 --- a/tests/amqpprox_connectionlimitermanager.t.cpp +++ b/tests/amqpprox_connectionlimitermanager.t.cpp @@ -31,8 +31,8 @@ using namespace testing; TEST(ConnectionLimiterManagerTest, Breathing) { ConnectionLimiterManager limiterManager; - EXPECT_EQ(limiterManager.getAlarmOnlyDefaultConnectionRateLimit(), 0); - EXPECT_EQ(limiterManager.getDefaultConnectionRateLimit(), 0); + EXPECT_FALSE(limiterManager.getAlarmOnlyDefaultConnectionRateLimit()); + EXPECT_FALSE(limiterManager.getDefaultConnectionRateLimit()); EXPECT_TRUE(limiterManager.getAlarmOnlyConnectionRateLimiter( "test-vhost") == nullptr); EXPECT_TRUE(limiterManager.getConnectionRateLimiter("test-vhost") == @@ -165,48 +165,52 @@ TEST(ConnectionLimiterManagerTest, AddGetRemoveAlarmOnlyConnectionRateLimiter) TEST(ConnectionLimiterManagerTest, SetGetRemoveDefaultConnectionRateLimiter) { ConnectionLimiterManager limiterManager; - EXPECT_EQ(limiterManager.getAlarmOnlyDefaultConnectionRateLimit(), 0); - EXPECT_EQ(limiterManager.getDefaultConnectionRateLimit(), 0); + EXPECT_FALSE(limiterManager.getAlarmOnlyDefaultConnectionRateLimit()); + EXPECT_FALSE(limiterManager.getDefaultConnectionRateLimit()); uint32_t connectionLimit1 = 100; // Setting default limiter limiterManager.setDefaultConnectionRateLimit(connectionLimit1); // Getting default limiter - EXPECT_EQ(limiterManager.getDefaultConnectionRateLimit(), + ASSERT_TRUE(limiterManager.getDefaultConnectionRateLimit()); + EXPECT_EQ(*limiterManager.getDefaultConnectionRateLimit(), connectionLimit1); - EXPECT_EQ(limiterManager.getAlarmOnlyDefaultConnectionRateLimit(), 0); + EXPECT_FALSE(limiterManager.getAlarmOnlyDefaultConnectionRateLimit()); uint32_t connectionLimit2 = 200; // Setting alarm only default limiter limiterManager.setAlarmOnlyDefaultConnectionRateLimit(connectionLimit2); // Getting alarm only default limiter - EXPECT_EQ(limiterManager.getAlarmOnlyDefaultConnectionRateLimit(), + ASSERT_TRUE(limiterManager.getAlarmOnlyDefaultConnectionRateLimit()); + EXPECT_EQ(*limiterManager.getAlarmOnlyDefaultConnectionRateLimit(), connectionLimit2); - EXPECT_EQ(limiterManager.getDefaultConnectionRateLimit(), + ASSERT_TRUE(limiterManager.getDefaultConnectionRateLimit()); + EXPECT_EQ(*limiterManager.getDefaultConnectionRateLimit(), connectionLimit1); // Removing default limiter limiterManager.removeDefaultConnectionRateLimit(); // Getting default limiter - EXPECT_EQ(limiterManager.getDefaultConnectionRateLimit(), 0); - EXPECT_EQ(limiterManager.getAlarmOnlyDefaultConnectionRateLimit(), + EXPECT_FALSE(limiterManager.getDefaultConnectionRateLimit()); + ASSERT_TRUE(limiterManager.getAlarmOnlyDefaultConnectionRateLimit()); + EXPECT_EQ(*limiterManager.getAlarmOnlyDefaultConnectionRateLimit(), connectionLimit2); // Removing alarm only default limiter limiterManager.removeAlarmOnlyDefaultConnectionRateLimit(); // Getting default limiter - EXPECT_EQ(limiterManager.getDefaultConnectionRateLimit(), 0); - EXPECT_EQ(limiterManager.getAlarmOnlyDefaultConnectionRateLimit(), 0); + EXPECT_FALSE(limiterManager.getDefaultConnectionRateLimit()); + EXPECT_FALSE(limiterManager.getAlarmOnlyDefaultConnectionRateLimit()); } TEST(ConnectionLimiterManagerTest, AllowNewConnectionForVhostWithoutAnyLimit) { ConnectionLimiterManager limiterManager; - EXPECT_EQ(limiterManager.getDefaultConnectionRateLimit(), 0); + EXPECT_FALSE(limiterManager.getDefaultConnectionRateLimit()); EXPECT_TRUE(limiterManager.allowNewConnectionForVhost("test-vhost")); EXPECT_TRUE(limiterManager.allowNewConnectionForVhost("test-vhost")); } @@ -216,13 +220,13 @@ TEST(ConnectionLimiterManagerTest, { ConnectionLimiterManager limiterManager; std::string vhostName = "test-vhost"; - EXPECT_EQ(limiterManager.getDefaultConnectionRateLimit(), 0); + EXPECT_FALSE(limiterManager.getDefaultConnectionRateLimit()); EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); uint32_t connectionLimit = 1; limiterManager.addConnectionRateLimiter(vhostName, connectionLimit); - EXPECT_EQ(limiterManager.getDefaultConnectionRateLimit(), 0); + EXPECT_FALSE(limiterManager.getDefaultConnectionRateLimit()); EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); EXPECT_FALSE(limiterManager.allowNewConnectionForVhost(vhostName)); } @@ -232,14 +236,14 @@ TEST(ConnectionLimiterManagerTest, { ConnectionLimiterManager limiterManager; std::string vhostName = "test-vhost"; - EXPECT_EQ(limiterManager.getAlarmOnlyDefaultConnectionRateLimit(), 0); + EXPECT_FALSE(limiterManager.getAlarmOnlyDefaultConnectionRateLimit()); EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); uint32_t connectionLimit = 1; limiterManager.addAlarmOnlyConnectionRateLimiter(vhostName, connectionLimit); - EXPECT_EQ(limiterManager.getAlarmOnlyDefaultConnectionRateLimit(), 0); + EXPECT_FALSE(limiterManager.getAlarmOnlyDefaultConnectionRateLimit()); EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); } @@ -249,13 +253,15 @@ TEST(ConnectionLimiterManagerTest, { ConnectionLimiterManager limiterManager; std::string vhostName = "test-vhost"; - EXPECT_EQ(limiterManager.getDefaultConnectionRateLimit(), 0); + EXPECT_FALSE(limiterManager.getDefaultConnectionRateLimit()); EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); uint32_t connectionLimit = 1; limiterManager.setDefaultConnectionRateLimit(connectionLimit); - EXPECT_EQ(limiterManager.getDefaultConnectionRateLimit(), connectionLimit); + ASSERT_TRUE(limiterManager.getDefaultConnectionRateLimit()); + EXPECT_EQ(*limiterManager.getDefaultConnectionRateLimit(), + connectionLimit); EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); EXPECT_FALSE(limiterManager.allowNewConnectionForVhost(vhostName)); } @@ -265,13 +271,14 @@ TEST(ConnectionLimiterManagerTest, { ConnectionLimiterManager limiterManager; std::string vhostName = "test-vhost"; - EXPECT_EQ(limiterManager.getAlarmOnlyDefaultConnectionRateLimit(), 0); + EXPECT_FALSE(limiterManager.getAlarmOnlyDefaultConnectionRateLimit()); EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); uint32_t connectionLimit = 1; limiterManager.setAlarmOnlyDefaultConnectionRateLimit(connectionLimit); - EXPECT_EQ(limiterManager.getAlarmOnlyDefaultConnectionRateLimit(), + ASSERT_TRUE(limiterManager.getAlarmOnlyDefaultConnectionRateLimit()); + EXPECT_EQ(*limiterManager.getAlarmOnlyDefaultConnectionRateLimit(), connectionLimit); EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); @@ -283,24 +290,30 @@ TEST(ConnectionLimiterManagerTest, using namespace std::chrono_literals; ConnectionLimiterManager limiterManager; std::string vhostName = "test-vhost"; - EXPECT_EQ(limiterManager.getDefaultConnectionRateLimit(), 0); + EXPECT_FALSE(limiterManager.getDefaultConnectionRateLimit()); EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); uint32_t connectionLimit = 1; limiterManager.setDefaultConnectionRateLimit(connectionLimit); - EXPECT_EQ(limiterManager.getDefaultConnectionRateLimit(), connectionLimit); + ASSERT_TRUE(limiterManager.getDefaultConnectionRateLimit()); + EXPECT_EQ(*limiterManager.getDefaultConnectionRateLimit(), + connectionLimit); EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); EXPECT_FALSE(limiterManager.allowNewConnectionForVhost(vhostName)); uint32_t newConnectionLimit = 2; limiterManager.addConnectionRateLimiter(vhostName, newConnectionLimit); - EXPECT_EQ(limiterManager.getDefaultConnectionRateLimit(), connectionLimit); + ASSERT_TRUE(limiterManager.getDefaultConnectionRateLimit()); + EXPECT_EQ(*limiterManager.getDefaultConnectionRateLimit(), + connectionLimit); EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); limiterManager.removeConnectionRateLimiter(vhostName); - EXPECT_EQ(limiterManager.getDefaultConnectionRateLimit(), connectionLimit); + ASSERT_TRUE(limiterManager.getDefaultConnectionRateLimit()); + EXPECT_EQ(*limiterManager.getDefaultConnectionRateLimit(), + connectionLimit); EXPECT_TRUE(limiterManager.allowNewConnectionForVhost(vhostName)); EXPECT_FALSE(limiterManager.allowNewConnectionForVhost(vhostName)); }