diff --git a/config/settings/base.py b/config/settings/base.py index 2895e5720..aefe4b0bf 100644 --- a/config/settings/base.py +++ b/config/settings/base.py @@ -522,6 +522,9 @@ # ------------------------------------------------------------------------------ EVENTS_QUEUE_URL = env("EVENTS_QUEUE_URL", default=None) EVENTS_QUEUE_EXCHANGE_NAME = env("EVENTS_QUEUE_EXCHANGE_NAME", default="amq.fanout") +EVENTS_QUEUE_POOL_CONNECTIONS_LIMIT = env.int( + "EVENTS_QUEUE_POOL_CONNECTIONS_LIMIT", default=0 +) # Cache CACHE_ALL_TXS_VIEW = env.int( diff --git a/safe_transaction_service/events/services/queue_service.py b/safe_transaction_service/events/services/queue_service.py index eb3d9eb7d..5c38f5212 100644 --- a/safe_transaction_service/events/services/queue_service.py +++ b/safe_transaction_service/events/services/queue_service.py @@ -76,26 +76,43 @@ def get_queue_service(): class QueueService: def __init__(self): self._connection_pool: List[BrokerConnection] = [] + self._total_connections: int = 0 self.unsent_events: List = [] - def get_connection(self) -> BrokerConnection: + def get_connection(self) -> Optional[BrokerConnection]: """ :return: A `BrokerConnection` from the connection pool if there is one available, othwerwise returns a new BrokerConnection """ + if ( + settings.EVENTS_QUEUE_POOL_CONNECTIONS_LIMIT + and self._total_connections >= settings.EVENTS_QUEUE_POOL_CONNECTIONS_LIMIT + ): + logger.warning( + "Number of active connections reached the pool limit: %d", + self._total_connections, + ) + return None + if self._connection_pool: - return self._connection_pool.pop() + broker_connection = self._connection_pool.pop() else: - return BrokerConnection() + broker_connection = BrokerConnection() + + self._total_connections += 1 + return broker_connection - def release_connection(self, broker_connection: BrokerConnection) -> None: + def release_connection(self, broker_connection: Optional[BrokerConnection]): """ Return the `BrokerConnection` to the pool :param broker_connection: :return: """ - return self._connection_pool.insert(0, broker_connection) + self._total_connections -= 1 + # Don't add broken connections to the pool + if broker_connection: + self._connection_pool.insert(0, broker_connection) def send_event(self, payload: Dict[str, Any]) -> int: """ @@ -103,9 +120,12 @@ def send_event(self, payload: Dict[str, Any]) -> int: :param payload: Number of events published """ - broker_connection = self.get_connection() - event = json.dumps(payload) + if not (broker_connection := self.get_connection()): + # No available connections in the pool, store event to send it later + self.unsent_events.append(event) + return 0 + if broker_connection.publish(event): logger.debug("Event correctly sent: %s", event) self.release_connection(broker_connection) @@ -114,6 +134,8 @@ def send_event(self, payload: Dict[str, Any]) -> int: logger.warning("Unable to send the event due to a connection error") logger.debug("Adding %s to unsent messages", payload) self.unsent_events.append(event) + # As the message cannot be sent, we don't want to send the problematic connection back to the pool, only reduce the number of total connections + self.release_connection(None) return 0 def send_unsent_events(self) -> int: @@ -125,7 +147,9 @@ def send_unsent_events(self) -> int: if not self.unsent_events: return 0 - broker_connection = self.get_connection() + if not (broker_connection := self.get_connection()): + # Connection not available in the pool + return 0 # Avoid race conditions unsent_events = self.unsent_events diff --git a/safe_transaction_service/events/tests/test_queue_service.py b/safe_transaction_service/events/tests/test_queue_service.py index e37a6648d..5607c302b 100644 --- a/safe_transaction_service/events/tests/test_queue_service.py +++ b/safe_transaction_service/events/tests/test_queue_service.py @@ -6,7 +6,7 @@ from pika.channel import Channel from pika.exceptions import ConnectionClosedByBroker -from ..services.queue_service import BrokerConnection, get_queue_service +from ..services.queue_service import BrokerConnection, QueueService, get_queue_service class TestQueueService(TestCase): @@ -58,10 +58,27 @@ def test_send_unsent_messages(self): _, _, body = broker_connection.channel.basic_get(self.queue, auto_ack=True) self.assertEqual(json.loads(body), payload) + def test_send_with_pool_limit(self): + queue_service = QueueService() + payload = "Pool limit test" + # Unused connection, just to reach the limit + connection_1 = queue_service.get_connection() + self.assertEqual(len(queue_service.unsent_events), 0) + self.assertEqual(queue_service.send_event(payload), 1) + with self.settings(EVENTS_QUEUE_POOL_CONNECTIONS_LIMIT=1): + self.assertEqual(queue_service._total_connections, 1) + self.assertEqual(len(queue_service.unsent_events), 0) + self.assertEqual(queue_service.send_event(payload), 0) + self.assertEqual(len(queue_service.unsent_events), 1) + queue_service.release_connection(connection_1) + self.assertEqual(len(queue_service.unsent_events), 1) + self.assertEqual(queue_service.send_event(payload), 2) + self.assertEqual(len(queue_service.unsent_events), 0) + def test_send_event_to_queue(self): payload = {"event": "test_event", "type": "event type"} - queue_service = get_queue_service() - # Clean previous pool connections + queue_service = QueueService() + # Clean previous connection pool queue_service._connection_pool = [] self.assertEqual(len(queue_service._connection_pool), 0) queue_service.send_event(payload) @@ -72,15 +89,29 @@ def test_send_event_to_queue(self): self.assertEqual(json.loads(body), payload) def test_get_connection(self): - queue_service = get_queue_service() - # Clean previous pool connections + queue_service = QueueService() + # Clean previous connection pool queue_service._connection_pool = [] self.assertEqual(len(queue_service._connection_pool), 0) + self.assertEqual(queue_service._total_connections, 0) connection_1 = queue_service.get_connection() self.assertEqual(len(queue_service._connection_pool), 0) + self.assertEqual(queue_service._total_connections, 1) connection_2 = queue_service.get_connection() self.assertEqual(len(queue_service._connection_pool), 0) + self.assertEqual(queue_service._total_connections, 2) queue_service.release_connection(connection_1) self.assertEqual(len(queue_service._connection_pool), 1) + self.assertEqual(queue_service._total_connections, 1) queue_service.release_connection(connection_2) self.assertEqual(len(queue_service._connection_pool), 2) + self.assertEqual(queue_service._total_connections, 0) + with self.settings(EVENTS_QUEUE_POOL_CONNECTIONS_LIMIT=1): + connection_1 = queue_service.get_connection() + self.assertEqual(len(queue_service._connection_pool), 1) + self.assertEqual(queue_service._total_connections, 1) + # We should reach the connection limit of the pool + connection_1 = queue_service.get_connection() + self.assertEqual(len(queue_service._connection_pool), 1) + self.assertEqual(queue_service._total_connections, 1) + self.assertIsNone(connection_1)