Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 74 additions & 18 deletions providers/amazon/src/airflow/providers/amazon/aws/hooks/sns.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import json

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.utils.helpers import prune_dict


def _get_message_attribute(o):
Expand All @@ -38,6 +39,33 @@ def _get_message_attribute(o):
)


def _build_publish_kwargs(
target_arn: str,
message: str,
subject: str | None = None,
message_attributes: dict | None = None,
message_deduplication_id: str | None = None,
message_group_id: str | None = None,
) -> dict[str, str | dict]:
publish_kwargs: dict[str, str | dict] = prune_dict(
{
"TargetArn": target_arn,
"MessageStructure": "json",
"Message": json.dumps({"default": message}),
"Subject": subject,
"MessageDeduplicationId": message_deduplication_id,
"MessageGroupId": message_group_id,
}
)

if message_attributes:
publish_kwargs["MessageAttributes"] = {
key: _get_message_attribute(val) for key, val in message_attributes.items()
}

return publish_kwargs


class SnsHook(AwsBaseHook):
"""
Interact with Amazon Simple Notification Service.
Expand Down Expand Up @@ -84,22 +112,50 @@ def publish_to_target(
:param message_group_id: Tag that specifies that a message belongs to a specific message group.
This parameter applies only to FIFO (first-in-first-out) topics.
"""
publish_kwargs: dict[str, str | dict] = {
"TargetArn": target_arn,
"MessageStructure": "json",
"Message": json.dumps({"default": message}),
}
return self.get_conn().publish(
**_build_publish_kwargs(
target_arn, message, subject, message_attributes, message_deduplication_id, message_group_id
)
)

# Construct args this way because boto3 distinguishes from missing args and those set to None
if subject:
publish_kwargs["Subject"] = subject
if message_deduplication_id:
publish_kwargs["MessageDeduplicationId"] = message_deduplication_id
if message_group_id:
publish_kwargs["MessageGroupId"] = message_group_id
if message_attributes:
publish_kwargs["MessageAttributes"] = {
key: _get_message_attribute(val) for key, val in message_attributes.items()
}

return self.get_conn().publish(**publish_kwargs)
async def apublish_to_target(
self,
target_arn: str,
message: str,
subject: str | None = None,
message_attributes: dict | None = None,
message_deduplication_id: str | None = None,
message_group_id: str | None = None,
):
"""
Publish a message to a SNS topic or an endpoint.

.. seealso::
- :external+boto3:py:meth:`SNS.Client.publish`

:param target_arn: either a TopicArn or an EndpointArn
:param message: the default message you want to send
:param subject: subject of message
:param message_attributes: additional attributes to publish for message filtering. This should be
a flat dict; the DataType to be sent depends on the type of the value:

- bytes = Binary
- str = String
- int, float = Number
- iterable = String.Array
:param message_deduplication_id: Every message must have a unique message_deduplication_id.
This parameter applies only to FIFO (first-in-first-out) topics.
:param message_group_id: Tag that specifies that a message belongs to a specific message group.
This parameter applies only to FIFO (first-in-first-out) topics.
"""
async with await self.get_async_conn() as async_client:
return await async_client.publish(
**_build_publish_kwargs(
target_arn,
message,
subject,
message_attributes,
message_deduplication_id,
message_group_id,
)
)
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from functools import cached_property

from airflow.providers.amazon.aws.hooks.sns import SnsHook
from airflow.providers.amazon.version_compat import AIRFLOW_V_3_1_PLUS
from airflow.providers.common.compat.notifier import BaseNotifier


Expand Down Expand Up @@ -60,8 +61,13 @@ def __init__(
subject: str | None = None,
message_attributes: dict | None = None,
region_name: str | None = None,
**kwargs,
):
super().__init__()
if AIRFLOW_V_3_1_PLUS:
# Support for passing context was added in 3.1.0
super().__init__(**kwargs)
else:
super().__init__()
self.aws_conn_id = aws_conn_id
self.region_name = region_name
self.target_arn = target_arn
Expand All @@ -83,5 +89,14 @@ def notify(self, context):
message_attributes=self.message_attributes,
)

async def async_notify(self, context):
"""Publish the notification message to Amazon SNS (async)."""
await self.hook.apublish_to_target(
target_arn=self.target_arn,
message=self.message,
subject=self.subject,
message_attributes=self.message_attributes,
)


send_sns_notification = SnsNotifier
168 changes: 102 additions & 66 deletions providers/amazon/tests/unit/amazon/aws/hooks/test_sns.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,95 +17,75 @@
# under the License.
from __future__ import annotations

from unittest import mock

import pytest
from moto import mock_aws

from airflow.providers.amazon.aws.hooks.sns import SnsHook

DEDUPE_ID = "test-dedupe-id"
GROUP_ID = "test-group-id"
MESSAGE = "Hello world"
TOPIC_NAME = "test-topic"
SUBJECT = "test-subject"
INVALID_ATTRIBUTES_MSG = r"Values in MessageAttributes must be one of bytes, str, int, float, or iterable"

TOPIC_NAME = "test-topic"
TOPIC_ARN = f"arn:aws:sns:us-east-1:123456789012:{TOPIC_NAME}"

@mock_aws
class TestSnsHook:
def test_get_conn_returns_a_boto3_connection(self):
hook = SnsHook(aws_conn_id="aws_default")
assert hook.get_conn() is not None

def test_publish_to_target_with_subject(self):
hook = SnsHook(aws_conn_id="aws_default")

message = MESSAGE
topic_name = TOPIC_NAME
subject = SUBJECT
target = hook.get_conn().create_topic(Name=topic_name).get("TopicArn")

response = hook.publish_to_target(target, message, subject)
INVALID_ATTRIBUTES = {"test-non-iterable": object()}
VALID_ATTRIBUTES = {
"test-string": "string-value",
"test-number": 123456,
"test-array": ["first", "second", "third"],
"test-binary": b"binary-value",
}

assert "MessageId" in response
MESSAGE_ID_KEY = "MessageId"
TOPIC_ARN_KEY = "TopicArn"

def test_publish_to_target_with_attributes(self):
hook = SnsHook(aws_conn_id="aws_default")

message = MESSAGE
topic_name = TOPIC_NAME
target = hook.get_conn().create_topic(Name=topic_name).get("TopicArn")
class TestSnsHook:
@pytest.fixture(autouse=True)
def setup_moto(self):
with mock_aws():
yield

response = hook.publish_to_target(
target,
message,
message_attributes={
"test-string": "string-value",
"test-number": 123456,
"test-array": ["first", "second", "third"],
"test-binary": b"binary-value",
},
)
@pytest.fixture
def hook(self):
return SnsHook(aws_conn_id="aws_default")

assert "MessageId" in response
@pytest.fixture
def target(self, hook):
return hook.get_conn().create_topic(Name=TOPIC_NAME).get(TOPIC_ARN_KEY)

def test_publish_to_target_plain(self):
hook = SnsHook(aws_conn_id="aws_default")
def test_get_conn_returns_a_boto3_connection(self, hook):
assert hook.get_conn() is not None

message = MESSAGE
topic_name = "test-topic"
target = hook.get_conn().create_topic(Name=topic_name).get("TopicArn")
def test_publish_to_target_with_subject(self, hook, target):
response = hook.publish_to_target(target, MESSAGE, SUBJECT)

response = hook.publish_to_target(target, message)
assert MESSAGE_ID_KEY in response

assert "MessageId" in response
def test_publish_to_target_with_attributes(self, hook, target):
response = hook.publish_to_target(target, MESSAGE, message_attributes=VALID_ATTRIBUTES)

def test_publish_to_target_error(self):
hook = SnsHook(aws_conn_id="aws_default")
assert MESSAGE_ID_KEY in response

message = "Hello world"
topic_name = "test-topic"
target = hook.get_conn().create_topic(Name=topic_name).get("TopicArn")
def test_publish_to_target_plain(self, hook, target):
response = hook.publish_to_target(target, MESSAGE)

error_message = (
r"Values in MessageAttributes must be one of bytes, str, int, float, or iterable; got .*"
)
with pytest.raises(TypeError, match=error_message):
hook.publish_to_target(
target,
message,
message_attributes={
"test-non-iterable": object(),
},
)
assert MESSAGE_ID_KEY in response

def test_publish_to_target_with_deduplication(self):
hook = SnsHook(aws_conn_id="aws_default")
def test_publish_to_target_error(self, hook, target):
with pytest.raises(TypeError, match=INVALID_ATTRIBUTES_MSG):
hook.publish_to_target(target, MESSAGE, message_attributes=INVALID_ATTRIBUTES)

message = MESSAGE
topic_name = TOPIC_NAME + ".fifo"
deduplication_id = "abc"
group_id = "a"
target = (
def test_publish_to_target_with_deduplication(self, hook):
fifo_target = (
hook.get_conn()
.create_topic(
Name=topic_name,
Name=f"{TOPIC_NAME}.fifo",
Attributes={
"FifoTopic": "true",
"ContentBasedDeduplication": "false",
Expand All @@ -115,7 +95,63 @@ def test_publish_to_target_with_deduplication(self):
)

response = hook.publish_to_target(
target, message, message_deduplication_id=deduplication_id, message_group_id=group_id
fifo_target, MESSAGE, message_deduplication_id=DEDUPE_ID, message_group_id=GROUP_ID
)
assert MESSAGE_ID_KEY in response


@pytest.mark.asyncio
class TestAsyncSnsHook:
"""The mock_aws decorator uses `moto` which does not currently support async SNS so we mock it manually."""

@pytest.fixture
def hook(self):
return SnsHook(aws_conn_id="aws_default")

@pytest.fixture
def mock_async_client(self):
mock_client = mock.AsyncMock()
mock_client.publish.return_value = {MESSAGE_ID_KEY: "test-message-id"}
return mock_client

@pytest.fixture
def mock_get_async_conn(self, mock_async_client):
with mock.patch.object(SnsHook, "get_async_conn") as mocked_conn:
mocked_conn.return_value = mock_async_client
mocked_conn.return_value.__aenter__.return_value = mock_async_client
yield mocked_conn

async def test_get_async_conn(self, hook, mock_get_async_conn, mock_async_client):
# Test context manager access
async with await hook.get_async_conn() as async_conn:
assert async_conn is mock_async_client

# Test direct access
async_conn = await hook.get_async_conn()
assert async_conn is mock_async_client

async def test_apublish_to_target_with_subject(self, hook, mock_get_async_conn, mock_async_client):
response = await hook.apublish_to_target(TOPIC_ARN, MESSAGE, SUBJECT)

assert MESSAGE_ID_KEY in response

async def test_apublish_to_target_with_attributes(self, hook, mock_get_async_conn, mock_async_client):
response = await hook.apublish_to_target(TOPIC_ARN, MESSAGE, message_attributes=VALID_ATTRIBUTES)

assert MESSAGE_ID_KEY in response

async def test_publish_to_target_plain(self, hook, mock_get_async_conn, mock_async_client):
response = await hook.apublish_to_target(TOPIC_ARN, MESSAGE)

assert MESSAGE_ID_KEY in response

async def test_publish_to_target_error(self, hook, mock_get_async_conn, mock_async_client):
with pytest.raises(TypeError, match=INVALID_ATTRIBUTES_MSG):
await hook.apublish_to_target(TOPIC_ARN, MESSAGE, message_attributes=INVALID_ATTRIBUTES)

async def test_apublish_to_target_with_deduplication(self, hook, mock_get_async_conn, mock_async_client):
response = await hook.apublish_to_target(
TOPIC_ARN, MESSAGE, message_deduplication_id=DEDUPE_ID, message_group_id=GROUP_ID
)

assert "MessageId" in response
assert MESSAGE_ID_KEY in response
Loading
Loading