Skip to content

Commit

Permalink
Add context to Azure Service Bus Message callback (#43370)
Browse files Browse the repository at this point in the history
* Add context to Azure Service Bus Message callback

The original callback only took the message as a paramter. However,
users may want to push status or location information into XComs and
so callbacks need access to the context (or the XComs, but context is
more general). This commit changes the code to pass the context

NOTE: This is a BREAKING CHANGE.

Fixes 43361

* Add breaking change note to CHANGELOG
  • Loading branch information
perry2of5 authored Oct 26, 2024
1 parent fbc1c2b commit 93ad181
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 18 deletions.
6 changes: 6 additions & 0 deletions providers/src/airflow/providers/microsoft/azure/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@
Changelog
---------

Breaking changes
~~~~~~~~~~~~~~~~
.. warning::
* We changed the message callback for Azure Service Bus messages to take two parameters, the message and the context, rather than just the message. This allows pushing message information into XComs. To upgrade from the previous version, which only took the message, please update your callback to take the context as a second parameter.


10.5.1
......

Expand Down
16 changes: 10 additions & 6 deletions providers/src/airflow/providers/microsoft/azure/hooks/asb.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@
get_sync_default_azure_credential,
)

MessageCallback = Callable[[ServiceBusMessage], None]


if TYPE_CHECKING:
from azure.identity import DefaultAzureCredential

from airflow.utils.context import Context

MessageCallback = Callable[[ServiceBusMessage, Context], None]


class BaseAzureServiceBusHook(BaseHook):
"""
Expand Down Expand Up @@ -283,6 +284,7 @@ def send_batch_message(sender: ServiceBusSender, messages: list[str]):
def receive_message(
self,
queue_name: str,
context: Context,
max_message_count: int | None = 1,
max_wait_time: float | None = None,
message_callback: MessageCallback | None = None,
Expand All @@ -309,12 +311,13 @@ def receive_message(
max_message_count=max_message_count, max_wait_time=max_wait_time
)
for msg in received_msgs:
self._process_message(msg, message_callback, receiver)
self._process_message(msg, context, message_callback, receiver)

def receive_subscription_message(
self,
topic_name: str,
subscription_name: str,
context: Context,
max_message_count: int | None,
max_wait_time: float | None,
message_callback: MessageCallback | None = None,
Expand Down Expand Up @@ -350,11 +353,12 @@ def receive_subscription_message(
max_message_count=max_message_count, max_wait_time=max_wait_time
)
for msg in received_msgs:
self._process_message(msg, message_callback, subscription_receiver)
self._process_message(msg, context, message_callback, subscription_receiver)

def _process_message(
self,
msg: ServiceBusReceivedMessage,
context: Context,
message_callback: MessageCallback | None,
receiver: ServiceBusReceiver,
):
Expand All @@ -372,7 +376,7 @@ def _process_message(
receiver.complete_message(msg)
else:
try:
message_callback(msg)
message_callback(msg, context)
except Exception as e:
self.log.error("Error processing message: %s", e)
receiver.abandon_message(msg)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

from airflow.utils.context import Context

MessageCallback = Callable[[ServiceBusMessage], None]
MessageCallback = Callable[[ServiceBusMessage, Context], None]


class AzureServiceBusCreateQueueOperator(BaseOperator):
Expand Down Expand Up @@ -176,6 +176,7 @@ def execute(self, context: Context) -> None:
# Receive message
hook.receive_message(
self.queue_name,
context,
max_message_count=self.max_message_count,
max_wait_time=self.max_wait_time,
message_callback=self.message_callback,
Expand Down Expand Up @@ -562,6 +563,7 @@ def execute(self, context: Context) -> None:
hook.receive_subscription_message(
self.topic_name,
self.subscription_name,
context,
self.max_message_count,
self.max_wait_time,
message_callback=self.message_callback,
Expand Down
22 changes: 16 additions & 6 deletions providers/tests/microsoft/azure/hooks/test_asb.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from airflow.models import Connection
from airflow.providers.microsoft.azure.hooks.asb import AdminClientHook, MessageHook
from airflow.utils.context import Context

MESSAGE = "Test Message"
MESSAGE_LIST = [f"{MESSAGE} {n}" for n in range(10)]
Expand Down Expand Up @@ -256,7 +257,7 @@ def test_receive_message(self, mock_sb_client, mock_service_bus_message):
mock_sb_client.return_value.get_queue_receiver.return_value.receive_messages.return_value = [
mock_service_bus_message
]
hook.receive_message(self.queue_name)
hook.receive_message(self.queue_name, Context())
expected_calls = [
mock.call()
.__enter__()
Expand Down Expand Up @@ -285,12 +286,13 @@ def test_receive_message_callback(self, mock_sb_client, mock_service_bus_message

received_messages = []

def message_callback(msg: Any) -> None:
def message_callback(msg: Any, context: Context) -> None:
nonlocal received_messages
print("received message:", msg)
assert context is not None
received_messages.append(msg)

hook.receive_message(self.queue_name, message_callback=message_callback)
hook.receive_message(self.queue_name, Context(), message_callback=message_callback)

assert len(received_messages) == 1
assert received_messages[0] == mock_service_bus_message
Expand All @@ -316,7 +318,9 @@ def test_receive_subscription_message(self, mock_sb_client):
max_message_count = 10
max_wait_time = 5
hook = MessageHook(azure_service_bus_conn_id=self.conn_id)
hook.receive_subscription_message(topic_name, subscription_name, max_message_count, max_wait_time)
hook.receive_subscription_message(
topic_name, subscription_name, Context(), max_message_count, max_wait_time
)
expected_calls = [
mock.call()
.__enter__()
Expand Down Expand Up @@ -350,13 +354,19 @@ def test_receive_subscription_message_callback(self, mock_sb_client, mock_sb_mes

received_messages = []

def message_callback(msg: ServiceBusMessage) -> None:
def message_callback(msg: ServiceBusMessage, context: Context) -> None:
nonlocal received_messages
print("received message:", msg)
assert context is not None
received_messages.append(msg)

hook.receive_subscription_message(
topic_name, subscription_name, max_message_count, max_wait_time, message_callback=message_callback
topic_name,
subscription_name,
Context(),
max_message_count,
max_wait_time,
message_callback=message_callback,
)

assert len(received_messages) == 2
Expand Down
14 changes: 9 additions & 5 deletions providers/tests/microsoft/azure/operators/test_asb.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from unittest import mock

import pytest
from azure.servicebus import ServiceBusMessage

try:
from azure.servicebus import ServiceBusMessage
Expand All @@ -37,6 +38,7 @@
AzureServiceBusTopicDeleteOperator,
AzureServiceBusUpdateSubscriptionOperator,
)
from airflow.utils.context import Context

QUEUE_NAME = "test_queue"
MESSAGE = "Test Message"
Expand Down Expand Up @@ -216,21 +218,22 @@ def test_receive_message_queue_callback(self, mock_get_conn):
Test AzureServiceBusReceiveMessageOperator by mock connection, values
and the service bus receive message
"""
mock_service_bus_message = ServiceBusMessage("Test message")
mock_service_bus_message = ServiceBusMessage("Test message with context")
mock_get_conn.return_value.__enter__.return_value.get_queue_receiver.return_value.__enter__.return_value.receive_messages.return_value = [
mock_service_bus_message
]

messages_received = []

def message_callback(msg):
def message_callback(msg: ServiceBusMessage, context: Context):
messages_received.append(msg)
assert context is not None
print(msg)

asb_receive_queue_operator = AzureServiceBusReceiveMessageOperator(
task_id="asb_receive_message_queue", queue_name=QUEUE_NAME, message_callback=message_callback
)
asb_receive_queue_operator.execute(None)
asb_receive_queue_operator.execute(Context())
assert len(messages_received) == 1
assert messages_received[0] == mock_service_bus_message

Expand Down Expand Up @@ -470,8 +473,9 @@ def test_receive_message_queue_callback(self, mock_get_conn):

messages_received = []

def message_callback(msg):
def message_callback(msg: ServiceBusMessage, context: Context):
messages_received.append(msg)
assert context is not None
print(msg)

asb_subscription_receive_message = ASBReceiveSubscriptionMessageOperator(
Expand All @@ -482,7 +486,7 @@ def message_callback(msg):
message_callback=message_callback,
)

asb_subscription_receive_message.execute(None)
asb_subscription_receive_message.execute(Context())
expected_calls = [
mock.call()
.__enter__()
Expand Down

0 comments on commit 93ad181

Please sign in to comment.