diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sns.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sns.py index 00315fbddc370..de309df4815db 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sns.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sns.py @@ -22,6 +22,7 @@ import json from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +from airflow.utils.helpers import prune_dict def _get_message_attribute(o): @@ -38,6 +39,33 @@ def _get_message_attribute(o): ) +def _build_publish_kwargs( + target_arn: str, + message: str, + subject: str | None = None, + message_attributes: dict | None = None, + message_deduplication_id: str | None = None, + message_group_id: str | None = None, +) -> dict[str, str | dict]: + publish_kwargs: dict[str, str | dict] = prune_dict( + { + "TargetArn": target_arn, + "MessageStructure": "json", + "Message": json.dumps({"default": message}), + "Subject": subject, + "MessageDeduplicationId": message_deduplication_id, + "MessageGroupId": message_group_id, + } + ) + + if message_attributes: + publish_kwargs["MessageAttributes"] = { + key: _get_message_attribute(val) for key, val in message_attributes.items() + } + + return publish_kwargs + + class SnsHook(AwsBaseHook): """ Interact with Amazon Simple Notification Service. @@ -84,22 +112,50 @@ def publish_to_target( :param message_group_id: Tag that specifies that a message belongs to a specific message group. This parameter applies only to FIFO (first-in-first-out) topics. """ - publish_kwargs: dict[str, str | dict] = { - "TargetArn": target_arn, - "MessageStructure": "json", - "Message": json.dumps({"default": message}), - } + return self.get_conn().publish( + **_build_publish_kwargs( + target_arn, message, subject, message_attributes, message_deduplication_id, message_group_id + ) + ) - # Construct args this way because boto3 distinguishes from missing args and those set to None - if subject: - publish_kwargs["Subject"] = subject - if message_deduplication_id: - publish_kwargs["MessageDeduplicationId"] = message_deduplication_id - if message_group_id: - publish_kwargs["MessageGroupId"] = message_group_id - if message_attributes: - publish_kwargs["MessageAttributes"] = { - key: _get_message_attribute(val) for key, val in message_attributes.items() - } - - return self.get_conn().publish(**publish_kwargs) + async def apublish_to_target( + self, + target_arn: str, + message: str, + subject: str | None = None, + message_attributes: dict | None = None, + message_deduplication_id: str | None = None, + message_group_id: str | None = None, + ): + """ + Publish a message to a SNS topic or an endpoint. + + .. seealso:: + - :external+boto3:py:meth:`SNS.Client.publish` + + :param target_arn: either a TopicArn or an EndpointArn + :param message: the default message you want to send + :param subject: subject of message + :param message_attributes: additional attributes to publish for message filtering. This should be + a flat dict; the DataType to be sent depends on the type of the value: + + - bytes = Binary + - str = String + - int, float = Number + - iterable = String.Array + :param message_deduplication_id: Every message must have a unique message_deduplication_id. + This parameter applies only to FIFO (first-in-first-out) topics. + :param message_group_id: Tag that specifies that a message belongs to a specific message group. + This parameter applies only to FIFO (first-in-first-out) topics. + """ + async with await self.get_async_conn() as async_client: + return await async_client.publish( + **_build_publish_kwargs( + target_arn, + message, + subject, + message_attributes, + message_deduplication_id, + message_group_id, + ) + ) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/notifications/sns.py b/providers/amazon/src/airflow/providers/amazon/aws/notifications/sns.py index c73d52e85cc46..62a3c037a7f1f 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/notifications/sns.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/notifications/sns.py @@ -21,6 +21,7 @@ from functools import cached_property from airflow.providers.amazon.aws.hooks.sns import SnsHook +from airflow.providers.amazon.version_compat import AIRFLOW_V_3_1_PLUS from airflow.providers.common.compat.notifier import BaseNotifier @@ -60,8 +61,13 @@ def __init__( subject: str | None = None, message_attributes: dict | None = None, region_name: str | None = None, + **kwargs, ): - super().__init__() + if AIRFLOW_V_3_1_PLUS: + # Support for passing context was added in 3.1.0 + super().__init__(**kwargs) + else: + super().__init__() self.aws_conn_id = aws_conn_id self.region_name = region_name self.target_arn = target_arn @@ -83,5 +89,14 @@ def notify(self, context): message_attributes=self.message_attributes, ) + async def async_notify(self, context): + """Publish the notification message to Amazon SNS (async).""" + await self.hook.apublish_to_target( + target_arn=self.target_arn, + message=self.message, + subject=self.subject, + message_attributes=self.message_attributes, + ) + send_sns_notification = SnsNotifier diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_sns.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_sns.py index 19043cd5f45fb..e39157ca1f1d4 100644 --- a/providers/amazon/tests/unit/amazon/aws/hooks/test_sns.py +++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_sns.py @@ -17,95 +17,75 @@ # under the License. from __future__ import annotations +from unittest import mock + import pytest from moto import mock_aws from airflow.providers.amazon.aws.hooks.sns import SnsHook +DEDUPE_ID = "test-dedupe-id" +GROUP_ID = "test-group-id" MESSAGE = "Hello world" -TOPIC_NAME = "test-topic" SUBJECT = "test-subject" +INVALID_ATTRIBUTES_MSG = r"Values in MessageAttributes must be one of bytes, str, int, float, or iterable" +TOPIC_NAME = "test-topic" +TOPIC_ARN = f"arn:aws:sns:us-east-1:123456789012:{TOPIC_NAME}" -@mock_aws -class TestSnsHook: - def test_get_conn_returns_a_boto3_connection(self): - hook = SnsHook(aws_conn_id="aws_default") - assert hook.get_conn() is not None - - def test_publish_to_target_with_subject(self): - hook = SnsHook(aws_conn_id="aws_default") - - message = MESSAGE - topic_name = TOPIC_NAME - subject = SUBJECT - target = hook.get_conn().create_topic(Name=topic_name).get("TopicArn") - - response = hook.publish_to_target(target, message, subject) +INVALID_ATTRIBUTES = {"test-non-iterable": object()} +VALID_ATTRIBUTES = { + "test-string": "string-value", + "test-number": 123456, + "test-array": ["first", "second", "third"], + "test-binary": b"binary-value", +} - assert "MessageId" in response +MESSAGE_ID_KEY = "MessageId" +TOPIC_ARN_KEY = "TopicArn" - def test_publish_to_target_with_attributes(self): - hook = SnsHook(aws_conn_id="aws_default") - message = MESSAGE - topic_name = TOPIC_NAME - target = hook.get_conn().create_topic(Name=topic_name).get("TopicArn") +class TestSnsHook: + @pytest.fixture(autouse=True) + def setup_moto(self): + with mock_aws(): + yield - response = hook.publish_to_target( - target, - message, - message_attributes={ - "test-string": "string-value", - "test-number": 123456, - "test-array": ["first", "second", "third"], - "test-binary": b"binary-value", - }, - ) + @pytest.fixture + def hook(self): + return SnsHook(aws_conn_id="aws_default") - assert "MessageId" in response + @pytest.fixture + def target(self, hook): + return hook.get_conn().create_topic(Name=TOPIC_NAME).get(TOPIC_ARN_KEY) - def test_publish_to_target_plain(self): - hook = SnsHook(aws_conn_id="aws_default") + def test_get_conn_returns_a_boto3_connection(self, hook): + assert hook.get_conn() is not None - message = MESSAGE - topic_name = "test-topic" - target = hook.get_conn().create_topic(Name=topic_name).get("TopicArn") + def test_publish_to_target_with_subject(self, hook, target): + response = hook.publish_to_target(target, MESSAGE, SUBJECT) - response = hook.publish_to_target(target, message) + assert MESSAGE_ID_KEY in response - assert "MessageId" in response + def test_publish_to_target_with_attributes(self, hook, target): + response = hook.publish_to_target(target, MESSAGE, message_attributes=VALID_ATTRIBUTES) - def test_publish_to_target_error(self): - hook = SnsHook(aws_conn_id="aws_default") + assert MESSAGE_ID_KEY in response - message = "Hello world" - topic_name = "test-topic" - target = hook.get_conn().create_topic(Name=topic_name).get("TopicArn") + def test_publish_to_target_plain(self, hook, target): + response = hook.publish_to_target(target, MESSAGE) - error_message = ( - r"Values in MessageAttributes must be one of bytes, str, int, float, or iterable; got .*" - ) - with pytest.raises(TypeError, match=error_message): - hook.publish_to_target( - target, - message, - message_attributes={ - "test-non-iterable": object(), - }, - ) + assert MESSAGE_ID_KEY in response - def test_publish_to_target_with_deduplication(self): - hook = SnsHook(aws_conn_id="aws_default") + def test_publish_to_target_error(self, hook, target): + with pytest.raises(TypeError, match=INVALID_ATTRIBUTES_MSG): + hook.publish_to_target(target, MESSAGE, message_attributes=INVALID_ATTRIBUTES) - message = MESSAGE - topic_name = TOPIC_NAME + ".fifo" - deduplication_id = "abc" - group_id = "a" - target = ( + def test_publish_to_target_with_deduplication(self, hook): + fifo_target = ( hook.get_conn() .create_topic( - Name=topic_name, + Name=f"{TOPIC_NAME}.fifo", Attributes={ "FifoTopic": "true", "ContentBasedDeduplication": "false", @@ -115,7 +95,63 @@ def test_publish_to_target_with_deduplication(self): ) response = hook.publish_to_target( - target, message, message_deduplication_id=deduplication_id, message_group_id=group_id + fifo_target, MESSAGE, message_deduplication_id=DEDUPE_ID, message_group_id=GROUP_ID + ) + assert MESSAGE_ID_KEY in response + + +@pytest.mark.asyncio +class TestAsyncSnsHook: + """The mock_aws decorator uses `moto` which does not currently support async SNS so we mock it manually.""" + + @pytest.fixture + def hook(self): + return SnsHook(aws_conn_id="aws_default") + + @pytest.fixture + def mock_async_client(self): + mock_client = mock.AsyncMock() + mock_client.publish.return_value = {MESSAGE_ID_KEY: "test-message-id"} + return mock_client + + @pytest.fixture + def mock_get_async_conn(self, mock_async_client): + with mock.patch.object(SnsHook, "get_async_conn") as mocked_conn: + mocked_conn.return_value = mock_async_client + mocked_conn.return_value.__aenter__.return_value = mock_async_client + yield mocked_conn + + async def test_get_async_conn(self, hook, mock_get_async_conn, mock_async_client): + # Test context manager access + async with await hook.get_async_conn() as async_conn: + assert async_conn is mock_async_client + + # Test direct access + async_conn = await hook.get_async_conn() + assert async_conn is mock_async_client + + async def test_apublish_to_target_with_subject(self, hook, mock_get_async_conn, mock_async_client): + response = await hook.apublish_to_target(TOPIC_ARN, MESSAGE, SUBJECT) + + assert MESSAGE_ID_KEY in response + + async def test_apublish_to_target_with_attributes(self, hook, mock_get_async_conn, mock_async_client): + response = await hook.apublish_to_target(TOPIC_ARN, MESSAGE, message_attributes=VALID_ATTRIBUTES) + + assert MESSAGE_ID_KEY in response + + async def test_publish_to_target_plain(self, hook, mock_get_async_conn, mock_async_client): + response = await hook.apublish_to_target(TOPIC_ARN, MESSAGE) + + assert MESSAGE_ID_KEY in response + + async def test_publish_to_target_error(self, hook, mock_get_async_conn, mock_async_client): + with pytest.raises(TypeError, match=INVALID_ATTRIBUTES_MSG): + await hook.apublish_to_target(TOPIC_ARN, MESSAGE, message_attributes=INVALID_ATTRIBUTES) + + async def test_apublish_to_target_with_deduplication(self, hook, mock_get_async_conn, mock_async_client): + response = await hook.apublish_to_target( + TOPIC_ARN, MESSAGE, message_deduplication_id=DEDUPE_ID, message_group_id=GROUP_ID ) - assert "MessageId" in response + assert MESSAGE_ID_KEY in response diff --git a/providers/amazon/tests/unit/amazon/aws/notifications/test_sns.py b/providers/amazon/tests/unit/amazon/aws/notifications/test_sns.py index 87f1bfb94cce4..b09098d0d4966 100644 --- a/providers/amazon/tests/unit/amazon/aws/notifications/test_sns.py +++ b/providers/amazon/tests/unit/amazon/aws/notifications/test_sns.py @@ -23,30 +23,44 @@ from airflow.providers.amazon.aws.notifications.sns import SnsNotifier, send_sns_notification from airflow.utils.types import NOTSET -PARAM_DEFAULT_VALUE = pytest.param(NOTSET, id="default-value") +PUBLISH_KWARGS = { + "target_arn": "arn:aws:sns:us-west-2:123456789098:TopicName", + "message": "foo-bar", + "subject": "spam-egg", + "message_attributes": {}, +} class TestSnsNotifier: def test_class_and_notifier_are_same(self): assert send_sns_notification is SnsNotifier - @pytest.mark.parametrize("aws_conn_id", ["aws_test_conn_id", None, PARAM_DEFAULT_VALUE]) - @pytest.mark.parametrize("region_name", ["eu-west-2", None, PARAM_DEFAULT_VALUE]) + @pytest.mark.parametrize( + "aws_conn_id", + [ + pytest.param("aws_test_conn_id", id="custom-conn"), + pytest.param(None, id="none-conn"), + pytest.param(NOTSET, id="default-value"), + ], + ) + @pytest.mark.parametrize( + "region_name", + [ + pytest.param("eu-west-2", id="custom-region"), + pytest.param(None, id="no-region"), + pytest.param(NOTSET, id="default-value"), + ], + ) def test_parameters_propagate_to_hook(self, aws_conn_id, region_name): """Test notifier attributes propagate to SnsHook.""" - publish_kwargs = { - "target_arn": "arn:aws:sns:us-west-2:123456789098:TopicName", - "message": "foo-bar", - "subject": "spam-egg", - "message_attributes": {}, - } + notifier_kwargs = {} if aws_conn_id is not NOTSET: notifier_kwargs["aws_conn_id"] = aws_conn_id if region_name is not NOTSET: notifier_kwargs["region_name"] = region_name - notifier = SnsNotifier(**notifier_kwargs, **publish_kwargs) + notifier = SnsNotifier(**notifier_kwargs, **PUBLISH_KWARGS) with mock.patch("airflow.providers.amazon.aws.notifications.sns.SnsHook") as mock_hook: hook = notifier.hook assert hook is notifier.hook, "Hook property not cached" @@ -57,7 +71,17 @@ def test_parameters_propagate_to_hook(self, aws_conn_id, region_name): # Basic check for notifier notifier.notify({}) - mock_hook.return_value.publish_to_target.assert_called_once_with(**publish_kwargs) + mock_hook.return_value.publish_to_target.assert_called_once_with(**PUBLISH_KWARGS) + + @pytest.mark.asyncio + async def test_async_notify(self): + notifier = SnsNotifier(**PUBLISH_KWARGS) + with mock.patch("airflow.providers.amazon.aws.notifications.sns.SnsHook") as mock_hook: + mock_hook.return_value.apublish_to_target = mock.AsyncMock() + + await notifier.async_notify({}) + + mock_hook.return_value.apublish_to_target.assert_called_once_with(**PUBLISH_KWARGS) def test_sns_notifier_templated(self, create_dag_without_db): notifier = SnsNotifier(