Skip to content

Commit

Permalink
Refactor git-agent to use on-message callback instead of async loop
Browse files Browse the repository at this point in the history
* Also adds acks for messages (this causes the change to have only one
  of these messages processed at the same type due to the prefetch_count
  of 1 set for the git-agent)
* Changes the code in the git-agent so that the shutdown process is a
  bit more graceful
* In a later version we'll look at changing the behaviour of the
  shutdown process so that we continue working on any active messages
  before everything is shutdown.
  • Loading branch information
ogenstad committed Mar 18, 2024
1 parent 8c8583a commit 501a1f8
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 44 deletions.
18 changes: 12 additions & 6 deletions backend/infrahub/cli/git_agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import logging
import signal
import sys
from asyncio import run as aiorun
from typing import Any

import typer
Expand All @@ -26,10 +25,11 @@

log = get_logger()

shutdown_event = asyncio.Event()


def signal_handler(*args: Any, **kwargs: Any) -> None: # pylint: disable=unused-argument
print("Git Agent terminated by user.")
sys.exit(0)
shutdown_event.set()


signal.signal(signal.SIGINT, signal_handler)
Expand Down Expand Up @@ -87,7 +87,13 @@ async def _start(debug: bool, port: int) -> None:

build_component_registry()

await service.message_bus.subscribe()
while not shutdown_event.is_set():
await asyncio.sleep(1)

log.info("Shutdown of Git agent requested")

await service.shutdown()
log.info("All services stopped")


@app.command()
Expand All @@ -110,4 +116,4 @@ def start(

config.load_and_exit(config_file_name=config_file)

aiorun(_start(debug=debug, port=port))
asyncio.run(_start(debug=debug, port=port))
3 changes: 0 additions & 3 deletions backend/infrahub/services/adapters/message_bus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,3 @@ async def reply(self, message: InfrahubMessage, routing_key: str) -> None:

async def rpc(self, message: InfrahubMessage, response_class: type[ResponseClass]) -> ResponseClass:
raise NotImplementedError()

async def subscribe(self) -> None:
raise NotImplementedError()
37 changes: 14 additions & 23 deletions backend/infrahub/services/adapters/message_bus/rabbitmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,14 @@ async def on_callback(self, message: AbstractIncomingMessage) -> None:
else:
self.service.log.error("Invalid message received", message=f"{message!r}")

async def on_message(self, message: AbstractIncomingMessage) -> None:
async with message.process():
clear_log_context()
if message.routing_key in messages.MESSAGE_MAP:
await execute_message(routing_key=message.routing_key, message_body=message.body, service=self.service)
else:
self.service.log.error("Invalid message received", message=f"{message!r}")

async def _initialize_api_server(self) -> None:
self.callback_queue = await self.channel.declare_queue(name=f"api-callback-{WORKER_IDENTITY}", exclusive=True)
self.events_queue = await self.channel.declare_queue(name=f"api-events-{WORKER_IDENTITY}", exclusive=True)
Expand Down Expand Up @@ -142,7 +150,6 @@ async def _initialize_api_server(self) -> None:
self.message_enrichers.append(_add_request_id)

async def _initialize_git_worker(self) -> None:
await self.channel.set_qos(prefetch_count=1)
events_queue = await self.channel.declare_queue(name=f"worker-events-{WORKER_IDENTITY}", exclusive=True)

self.exchange = await self.channel.declare_exchange(
Expand All @@ -157,6 +164,12 @@ async def _initialize_git_worker(self) -> None:
)
await self.callback_queue.consume(self.on_callback, no_ack=True)

message_channel = await self.connection.channel()
await message_channel.set_qos(prefetch_count=2)

queue = await message_channel.get_queue(f"{self.settings.namespace}.rpcs")
await queue.consume(callback=self.on_message, no_ack=False)

async def publish(self, message: InfrahubMessage, routing_key: str, delay: Optional[MessageTTL] = None) -> None:
for enricher in self.message_enrichers:
await enricher(message)
Expand Down Expand Up @@ -186,28 +199,6 @@ async def rpc(self, message: InfrahubMessage, response_class: Type[ResponseClass
data = json.loads(response.body)
return response_class(**data)

async def subscribe(self) -> None:
queue = await self.channel.get_queue(f"{self.settings.namespace}.rpcs")
self.service.log.info("Waiting for RPC instructions to execute .. ")
async with queue.iterator() as qiterator:
async for message in qiterator:
try:
async with message.process(requeue=False):
clear_log_context()
if message.routing_key in messages.MESSAGE_MAP:
await execute_message(
routing_key=message.routing_key, message_body=message.body, service=self.service
)
else:
self.service.log.error(
"Unhandled routing key for message",
routing_key=message.routing_key,
message=message.body,
)

except Exception: # pylint: disable=broad-except
self.service.log.exception("Processing error for message %r" % message)

@staticmethod
def format_message(message: InfrahubMessage) -> aio_pika.Message:
pika_message = aio_pika.Message(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -431,8 +431,8 @@ async def test_rabbitmq_rpc(rabbitmq_api: RabbitMQManager, fake_log: FakeLogger)
assert response.data.response == "Reply to: You can reply to this message"


async def test_rabbitmq_subscribe(rabbitmq_api: RabbitMQManager, fake_log: FakeLogger) -> None:
"""Validates the subscribe method."""
async def test_rabbitmq_on_message(rabbitmq_api: RabbitMQManager, fake_log: FakeLogger) -> None:
"""Validates the on_message method."""

bus = RabbitMQMessageBus(settings=rabbitmq_api.settings)
api_service = InfrahubServices(message_bus=bus, component_type=ComponentType.API_SERVER)
Expand All @@ -442,18 +442,15 @@ async def test_rabbitmq_subscribe(rabbitmq_api: RabbitMQManager, fake_log: FakeL

await bus.initialize(service=agent_service)

subscribe_task = asyncio.create_task(bus.subscribe())

await agent_service.send(message=messages.SendEchoRequest(message="Hello there"))
await asyncio.sleep(delay=1)
await bus.shutdown()
subscribe_task.cancel()

assert fake_log.info_logs == ["Waiting for RPC instructions to execute .. ", "Received message: Hello there"]
assert fake_log.info_logs == ["Received message: Hello there"]
assert fake_log.error_logs == []


async def test_rabbitmq_subscribe_invalid_routing_key(rabbitmq_api: RabbitMQManager, fake_log: FakeLogger) -> None:
async def test_rabbitmq_on_message_invalid_routing_key(rabbitmq_api: RabbitMQManager, fake_log: FakeLogger) -> None:
"""Validates logging of invalid routing key"""

bus = RabbitMQMessageBus(settings=rabbitmq_api.settings)
Expand All @@ -464,15 +461,12 @@ async def test_rabbitmq_subscribe_invalid_routing_key(rabbitmq_api: RabbitMQMana

await bus.initialize(service=agent_service)

subscribe_task = asyncio.create_task(bus.subscribe())

await bus.publish(routing_key="request.something.invalid", message=messages.SendEchoRequest(message="Hello there"))
await asyncio.sleep(delay=1)
await bus.shutdown()
subscribe_task.cancel()

assert fake_log.info_logs == ["Waiting for RPC instructions to execute .. "]
assert fake_log.error_logs == ["Unhandled routing key for message"]
assert fake_log.info_logs == []
assert fake_log.error_logs == ["Invalid message received"]


async def on_callback(message: AbstractIncomingMessage, service: InfrahubServices) -> None:
Expand Down

0 comments on commit 501a1f8

Please sign in to comment.