From 8793230a1cd28b8c13bf74aa3d9e172514d533d7 Mon Sep 17 00:00:00 2001 From: Tim Perry Date: Thu, 24 Oct 2024 15:02:08 -0700 Subject: [PATCH] Add context to Azure Service Bus Message callback The original callback only took the message as a paramter. However, users may want to push status or location information into XComs and so callbacks need access to the context (or the XComs, but context is more general). This commit changes the code to pass the context NOTE: This is a BREAKING CHANGE. Fixes 43361 --- .../providers/microsoft/azure/hooks/asb.py | 16 +++++++++----- .../microsoft/azure/operators/asb.py | 4 +++- .../tests/microsoft/azure/hooks/test_asb.py | 22 ++++++++++++++----- .../microsoft/azure/operators/test_asb.py | 14 +++++++----- 4 files changed, 38 insertions(+), 18 deletions(-) diff --git a/providers/src/airflow/providers/microsoft/azure/hooks/asb.py b/providers/src/airflow/providers/microsoft/azure/hooks/asb.py index 317447d111703..1dafe3c7f3c87 100644 --- a/providers/src/airflow/providers/microsoft/azure/hooks/asb.py +++ b/providers/src/airflow/providers/microsoft/azure/hooks/asb.py @@ -34,12 +34,13 @@ get_sync_default_azure_credential, ) -MessageCallback = Callable[[ServiceBusMessage], None] - - if TYPE_CHECKING: from azure.identity import DefaultAzureCredential + from airflow.utils.context import Context + + MessageCallback = Callable[[ServiceBusMessage, Context], None] + class BaseAzureServiceBusHook(BaseHook): """ @@ -283,6 +284,7 @@ def send_batch_message(sender: ServiceBusSender, messages: list[str]): def receive_message( self, queue_name: str, + context: Context, max_message_count: int | None = 1, max_wait_time: float | None = None, message_callback: MessageCallback | None = None, @@ -309,12 +311,13 @@ def receive_message( max_message_count=max_message_count, max_wait_time=max_wait_time ) for msg in received_msgs: - self._process_message(msg, message_callback, receiver) + self._process_message(msg, context, message_callback, receiver) def receive_subscription_message( self, topic_name: str, subscription_name: str, + context: Context, max_message_count: int | None, max_wait_time: float | None, message_callback: MessageCallback | None = None, @@ -350,11 +353,12 @@ def receive_subscription_message( max_message_count=max_message_count, max_wait_time=max_wait_time ) for msg in received_msgs: - self._process_message(msg, message_callback, subscription_receiver) + self._process_message(msg, context, message_callback, subscription_receiver) def _process_message( self, msg: ServiceBusReceivedMessage, + context: Context, message_callback: MessageCallback | None, receiver: ServiceBusReceiver, ): @@ -372,7 +376,7 @@ def _process_message( receiver.complete_message(msg) else: try: - message_callback(msg) + message_callback(msg, context) except Exception as e: self.log.error("Error processing message: %s", e) receiver.abandon_message(msg) diff --git a/providers/src/airflow/providers/microsoft/azure/operators/asb.py b/providers/src/airflow/providers/microsoft/azure/operators/asb.py index 85619526cfb93..7d6bab0d625f6 100644 --- a/providers/src/airflow/providers/microsoft/azure/operators/asb.py +++ b/providers/src/airflow/providers/microsoft/azure/operators/asb.py @@ -31,7 +31,7 @@ from airflow.utils.context import Context - MessageCallback = Callable[[ServiceBusMessage], None] + MessageCallback = Callable[[ServiceBusMessage, Context], None] class AzureServiceBusCreateQueueOperator(BaseOperator): @@ -176,6 +176,7 @@ def execute(self, context: Context) -> None: # Receive message hook.receive_message( self.queue_name, + context, max_message_count=self.max_message_count, max_wait_time=self.max_wait_time, message_callback=self.message_callback, @@ -562,6 +563,7 @@ def execute(self, context: Context) -> None: hook.receive_subscription_message( self.topic_name, self.subscription_name, + context, self.max_message_count, self.max_wait_time, message_callback=self.message_callback, diff --git a/providers/tests/microsoft/azure/hooks/test_asb.py b/providers/tests/microsoft/azure/hooks/test_asb.py index 83e04833bf07f..6d090e6653dd0 100644 --- a/providers/tests/microsoft/azure/hooks/test_asb.py +++ b/providers/tests/microsoft/azure/hooks/test_asb.py @@ -33,6 +33,7 @@ from airflow.models import Connection from airflow.providers.microsoft.azure.hooks.asb import AdminClientHook, MessageHook +from airflow.utils.context import Context MESSAGE = "Test Message" MESSAGE_LIST = [f"{MESSAGE} {n}" for n in range(10)] @@ -256,7 +257,7 @@ def test_receive_message(self, mock_sb_client, mock_service_bus_message): mock_sb_client.return_value.get_queue_receiver.return_value.receive_messages.return_value = [ mock_service_bus_message ] - hook.receive_message(self.queue_name) + hook.receive_message(self.queue_name, Context()) expected_calls = [ mock.call() .__enter__() @@ -285,12 +286,13 @@ def test_receive_message_callback(self, mock_sb_client, mock_service_bus_message received_messages = [] - def message_callback(msg: Any) -> None: + def message_callback(msg: Any, context: Context) -> None: nonlocal received_messages print("received message:", msg) + assert context is not None received_messages.append(msg) - hook.receive_message(self.queue_name, message_callback=message_callback) + hook.receive_message(self.queue_name, Context(), message_callback=message_callback) assert len(received_messages) == 1 assert received_messages[0] == mock_service_bus_message @@ -316,7 +318,9 @@ def test_receive_subscription_message(self, mock_sb_client): max_message_count = 10 max_wait_time = 5 hook = MessageHook(azure_service_bus_conn_id=self.conn_id) - hook.receive_subscription_message(topic_name, subscription_name, max_message_count, max_wait_time) + hook.receive_subscription_message( + topic_name, subscription_name, Context(), max_message_count, max_wait_time + ) expected_calls = [ mock.call() .__enter__() @@ -350,13 +354,19 @@ def test_receive_subscription_message_callback(self, mock_sb_client, mock_sb_mes received_messages = [] - def message_callback(msg: ServiceBusMessage) -> None: + def message_callback(msg: ServiceBusMessage, context: Context) -> None: nonlocal received_messages print("received message:", msg) + assert context is not None received_messages.append(msg) hook.receive_subscription_message( - topic_name, subscription_name, max_message_count, max_wait_time, message_callback=message_callback + topic_name, + subscription_name, + Context(), + max_message_count, + max_wait_time, + message_callback=message_callback, ) assert len(received_messages) == 2 diff --git a/providers/tests/microsoft/azure/operators/test_asb.py b/providers/tests/microsoft/azure/operators/test_asb.py index 42b770095b4e7..7e0c953890c22 100644 --- a/providers/tests/microsoft/azure/operators/test_asb.py +++ b/providers/tests/microsoft/azure/operators/test_asb.py @@ -19,6 +19,7 @@ from unittest import mock import pytest +from azure.servicebus import ServiceBusMessage try: from azure.servicebus import ServiceBusMessage @@ -37,6 +38,7 @@ AzureServiceBusTopicDeleteOperator, AzureServiceBusUpdateSubscriptionOperator, ) +from airflow.utils.context import Context QUEUE_NAME = "test_queue" MESSAGE = "Test Message" @@ -216,21 +218,22 @@ def test_receive_message_queue_callback(self, mock_get_conn): Test AzureServiceBusReceiveMessageOperator by mock connection, values and the service bus receive message """ - mock_service_bus_message = ServiceBusMessage("Test message") + mock_service_bus_message = ServiceBusMessage("Test message with context") mock_get_conn.return_value.__enter__.return_value.get_queue_receiver.return_value.__enter__.return_value.receive_messages.return_value = [ mock_service_bus_message ] messages_received = [] - def message_callback(msg): + def message_callback(msg: ServiceBusMessage, context: Context): messages_received.append(msg) + assert context is not None print(msg) asb_receive_queue_operator = AzureServiceBusReceiveMessageOperator( task_id="asb_receive_message_queue", queue_name=QUEUE_NAME, message_callback=message_callback ) - asb_receive_queue_operator.execute(None) + asb_receive_queue_operator.execute(Context()) assert len(messages_received) == 1 assert messages_received[0] == mock_service_bus_message @@ -470,8 +473,9 @@ def test_receive_message_queue_callback(self, mock_get_conn): messages_received = [] - def message_callback(msg): + def message_callback(msg: ServiceBusMessage, context: Context): messages_received.append(msg) + assert context is not None print(msg) asb_subscription_receive_message = ASBReceiveSubscriptionMessageOperator( @@ -482,7 +486,7 @@ def message_callback(msg): message_callback=message_callback, ) - asb_subscription_receive_message.execute(None) + asb_subscription_receive_message.execute(Context()) expected_calls = [ mock.call() .__enter__()