Skip to content

Commit

Permalink
Add context to Azure Service Bus Message callback
Browse files Browse the repository at this point in the history
The original callback only took the message as a paramter. However,
users may want to push status or location information into XComs and
so callbacks need access to the context (or the XComs, but context is
more general). This commit changes the code to pass the context

NOTE: This is a BREAKING CHANGE.

Fixes 43361
  • Loading branch information
perry2of5 committed Oct 24, 2024
1 parent 7a15849 commit 8793230
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 18 deletions.
16 changes: 10 additions & 6 deletions providers/src/airflow/providers/microsoft/azure/hooks/asb.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@
get_sync_default_azure_credential,
)

MessageCallback = Callable[[ServiceBusMessage], None]


if TYPE_CHECKING:
from azure.identity import DefaultAzureCredential

from airflow.utils.context import Context

MessageCallback = Callable[[ServiceBusMessage, Context], None]


class BaseAzureServiceBusHook(BaseHook):
"""
Expand Down Expand Up @@ -283,6 +284,7 @@ def send_batch_message(sender: ServiceBusSender, messages: list[str]):
def receive_message(
self,
queue_name: str,
context: Context,
max_message_count: int | None = 1,
max_wait_time: float | None = None,
message_callback: MessageCallback | None = None,
Expand All @@ -309,12 +311,13 @@ def receive_message(
max_message_count=max_message_count, max_wait_time=max_wait_time
)
for msg in received_msgs:
self._process_message(msg, message_callback, receiver)
self._process_message(msg, context, message_callback, receiver)

def receive_subscription_message(
self,
topic_name: str,
subscription_name: str,
context: Context,
max_message_count: int | None,
max_wait_time: float | None,
message_callback: MessageCallback | None = None,
Expand Down Expand Up @@ -350,11 +353,12 @@ def receive_subscription_message(
max_message_count=max_message_count, max_wait_time=max_wait_time
)
for msg in received_msgs:
self._process_message(msg, message_callback, subscription_receiver)
self._process_message(msg, context, message_callback, subscription_receiver)

def _process_message(
self,
msg: ServiceBusReceivedMessage,
context: Context,
message_callback: MessageCallback | None,
receiver: ServiceBusReceiver,
):
Expand All @@ -372,7 +376,7 @@ def _process_message(
receiver.complete_message(msg)
else:
try:
message_callback(msg)
message_callback(msg, context)
except Exception as e:
self.log.error("Error processing message: %s", e)
receiver.abandon_message(msg)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

from airflow.utils.context import Context

MessageCallback = Callable[[ServiceBusMessage], None]
MessageCallback = Callable[[ServiceBusMessage, Context], None]


class AzureServiceBusCreateQueueOperator(BaseOperator):
Expand Down Expand Up @@ -176,6 +176,7 @@ def execute(self, context: Context) -> None:
# Receive message
hook.receive_message(
self.queue_name,
context,
max_message_count=self.max_message_count,
max_wait_time=self.max_wait_time,
message_callback=self.message_callback,
Expand Down Expand Up @@ -562,6 +563,7 @@ def execute(self, context: Context) -> None:
hook.receive_subscription_message(
self.topic_name,
self.subscription_name,
context,
self.max_message_count,
self.max_wait_time,
message_callback=self.message_callback,
Expand Down
22 changes: 16 additions & 6 deletions providers/tests/microsoft/azure/hooks/test_asb.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from airflow.models import Connection
from airflow.providers.microsoft.azure.hooks.asb import AdminClientHook, MessageHook
from airflow.utils.context import Context

MESSAGE = "Test Message"
MESSAGE_LIST = [f"{MESSAGE} {n}" for n in range(10)]
Expand Down Expand Up @@ -256,7 +257,7 @@ def test_receive_message(self, mock_sb_client, mock_service_bus_message):
mock_sb_client.return_value.get_queue_receiver.return_value.receive_messages.return_value = [
mock_service_bus_message
]
hook.receive_message(self.queue_name)
hook.receive_message(self.queue_name, Context())
expected_calls = [
mock.call()
.__enter__()
Expand Down Expand Up @@ -285,12 +286,13 @@ def test_receive_message_callback(self, mock_sb_client, mock_service_bus_message

received_messages = []

def message_callback(msg: Any) -> None:
def message_callback(msg: Any, context: Context) -> None:
nonlocal received_messages
print("received message:", msg)
assert context is not None
received_messages.append(msg)

hook.receive_message(self.queue_name, message_callback=message_callback)
hook.receive_message(self.queue_name, Context(), message_callback=message_callback)

assert len(received_messages) == 1
assert received_messages[0] == mock_service_bus_message
Expand All @@ -316,7 +318,9 @@ def test_receive_subscription_message(self, mock_sb_client):
max_message_count = 10
max_wait_time = 5
hook = MessageHook(azure_service_bus_conn_id=self.conn_id)
hook.receive_subscription_message(topic_name, subscription_name, max_message_count, max_wait_time)
hook.receive_subscription_message(
topic_name, subscription_name, Context(), max_message_count, max_wait_time
)
expected_calls = [
mock.call()
.__enter__()
Expand Down Expand Up @@ -350,13 +354,19 @@ def test_receive_subscription_message_callback(self, mock_sb_client, mock_sb_mes

received_messages = []

def message_callback(msg: ServiceBusMessage) -> None:
def message_callback(msg: ServiceBusMessage, context: Context) -> None:
nonlocal received_messages
print("received message:", msg)
assert context is not None
received_messages.append(msg)

hook.receive_subscription_message(
topic_name, subscription_name, max_message_count, max_wait_time, message_callback=message_callback
topic_name,
subscription_name,
Context(),
max_message_count,
max_wait_time,
message_callback=message_callback,
)

assert len(received_messages) == 2
Expand Down
14 changes: 9 additions & 5 deletions providers/tests/microsoft/azure/operators/test_asb.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from unittest import mock

import pytest
from azure.servicebus import ServiceBusMessage

try:
from azure.servicebus import ServiceBusMessage
Expand All @@ -37,6 +38,7 @@
AzureServiceBusTopicDeleteOperator,
AzureServiceBusUpdateSubscriptionOperator,
)
from airflow.utils.context import Context

QUEUE_NAME = "test_queue"
MESSAGE = "Test Message"
Expand Down Expand Up @@ -216,21 +218,22 @@ def test_receive_message_queue_callback(self, mock_get_conn):
Test AzureServiceBusReceiveMessageOperator by mock connection, values
and the service bus receive message
"""
mock_service_bus_message = ServiceBusMessage("Test message")
mock_service_bus_message = ServiceBusMessage("Test message with context")
mock_get_conn.return_value.__enter__.return_value.get_queue_receiver.return_value.__enter__.return_value.receive_messages.return_value = [
mock_service_bus_message
]

messages_received = []

def message_callback(msg):
def message_callback(msg: ServiceBusMessage, context: Context):
messages_received.append(msg)
assert context is not None
print(msg)

asb_receive_queue_operator = AzureServiceBusReceiveMessageOperator(
task_id="asb_receive_message_queue", queue_name=QUEUE_NAME, message_callback=message_callback
)
asb_receive_queue_operator.execute(None)
asb_receive_queue_operator.execute(Context())
assert len(messages_received) == 1
assert messages_received[0] == mock_service_bus_message

Expand Down Expand Up @@ -470,8 +473,9 @@ def test_receive_message_queue_callback(self, mock_get_conn):

messages_received = []

def message_callback(msg):
def message_callback(msg: ServiceBusMessage, context: Context):
messages_received.append(msg)
assert context is not None
print(msg)

asb_subscription_receive_message = ASBReceiveSubscriptionMessageOperator(
Expand All @@ -482,7 +486,7 @@ def message_callback(msg):
message_callback=message_callback,
)

asb_subscription_receive_message.execute(None)
asb_subscription_receive_message.execute(Context())
expected_calls = [
mock.call()
.__enter__()
Expand Down

0 comments on commit 8793230

Please sign in to comment.