diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sns.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sns.py index 1c9a99b72f6e6..00315fbddc370 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sns.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sns.py @@ -60,6 +60,8 @@ def publish_to_target( message: str, subject: str | None = None, message_attributes: dict | None = None, + message_deduplication_id: str | None = None, + message_group_id: str | None = None, ): """ Publish a message to a SNS topic or an endpoint. @@ -77,7 +79,10 @@ def publish_to_target( - str = String - int, float = Number - iterable = String.Array - + :param message_deduplication_id: Every message must have a unique message_deduplication_id. + This parameter applies only to FIFO (first-in-first-out) topics. + :param message_group_id: Tag that specifies that a message belongs to a specific message group. + This parameter applies only to FIFO (first-in-first-out) topics. """ publish_kwargs: dict[str, str | dict] = { "TargetArn": target_arn, @@ -88,6 +93,10 @@ def publish_to_target( # Construct args this way because boto3 distinguishes from missing args and those set to None if subject: publish_kwargs["Subject"] = subject + if message_deduplication_id: + publish_kwargs["MessageDeduplicationId"] = message_deduplication_id + if message_group_id: + publish_kwargs["MessageGroupId"] = message_group_id if message_attributes: publish_kwargs["MessageAttributes"] = { key: _get_message_attribute(val) for key, val in message_attributes.items() diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/sns.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/sns.py index c8a29355edbe2..5b13e58863a1d 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/sns.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/sns.py @@ -53,6 +53,10 @@ class SnsPublishOperator(AwsBaseOperator[SnsHook]): https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html :param botocore_config: Configuration dictionary (key-values) for botocore client. See: https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + :param message_deduplication_id: Every message must have a unique message_deduplication_id. + This parameter applies only to FIFO (first-in-first-out) topics. + :param message_group_id: Tag that specifies that a message belongs to a specific message group. + This parameter applies only to FIFO (first-in-first-out) topics. """ aws_hook_class = SnsHook @@ -61,6 +65,8 @@ class SnsPublishOperator(AwsBaseOperator[SnsHook]): "message", "subject", "message_attributes", + "message_deduplication_id", + "message_group_id", ) template_fields_renderers = {"message_attributes": "json"} @@ -71,6 +77,8 @@ def __init__( message: str, subject: str | None = None, message_attributes: dict | None = None, + message_deduplication_id: str | None = None, + message_group_id: str | None = None, **kwargs, ): super().__init__(**kwargs) @@ -78,15 +86,19 @@ def __init__( self.message = message self.subject = subject self.message_attributes = message_attributes + self.message_deduplication_id = message_deduplication_id + self.message_group_id = message_group_id def execute(self, context: Context): self.log.info( - "Sending SNS notification to %s using %s:\nsubject=%s\nattributes=%s\nmessage=%s", + "Sending SNS notification to %s using %s:\nsubject=%s\nattributes=%s\nmessage=%s\nmessage_deduplication_id=%s\nmessage_group_id=%s", self.target_arn, self.aws_conn_id, self.subject, self.message_attributes, self.message, + self.message_deduplication_id, + self.message_group_id, ) return self.hook.publish_to_target( @@ -94,4 +106,6 @@ def execute(self, context: Context): message=self.message, subject=self.subject, message_attributes=self.message_attributes, + message_deduplication_id=self.message_deduplication_id, + message_group_id=self.message_group_id, ) diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_sns.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_sns.py index d204a56a5dc21..19043cd5f45fb 100644 --- a/providers/amazon/tests/unit/amazon/aws/hooks/test_sns.py +++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_sns.py @@ -22,6 +22,10 @@ from airflow.providers.amazon.aws.hooks.sns import SnsHook +MESSAGE = "Hello world" +TOPIC_NAME = "test-topic" +SUBJECT = "test-subject" + @mock_aws class TestSnsHook: @@ -32,9 +36,9 @@ def test_get_conn_returns_a_boto3_connection(self): def test_publish_to_target_with_subject(self): hook = SnsHook(aws_conn_id="aws_default") - message = "Hello world" - topic_name = "test-topic" - subject = "test-subject" + message = MESSAGE + topic_name = TOPIC_NAME + subject = SUBJECT target = hook.get_conn().create_topic(Name=topic_name).get("TopicArn") response = hook.publish_to_target(target, message, subject) @@ -44,8 +48,8 @@ def test_publish_to_target_with_subject(self): def test_publish_to_target_with_attributes(self): hook = SnsHook(aws_conn_id="aws_default") - message = "Hello world" - topic_name = "test-topic" + message = MESSAGE + topic_name = TOPIC_NAME target = hook.get_conn().create_topic(Name=topic_name).get("TopicArn") response = hook.publish_to_target( @@ -64,7 +68,7 @@ def test_publish_to_target_with_attributes(self): def test_publish_to_target_plain(self): hook = SnsHook(aws_conn_id="aws_default") - message = "Hello world" + message = MESSAGE topic_name = "test-topic" target = hook.get_conn().create_topic(Name=topic_name).get("TopicArn") @@ -90,3 +94,28 @@ def test_publish_to_target_error(self): "test-non-iterable": object(), }, ) + + def test_publish_to_target_with_deduplication(self): + hook = SnsHook(aws_conn_id="aws_default") + + message = MESSAGE + topic_name = TOPIC_NAME + ".fifo" + deduplication_id = "abc" + group_id = "a" + target = ( + hook.get_conn() + .create_topic( + Name=topic_name, + Attributes={ + "FifoTopic": "true", + "ContentBasedDeduplication": "false", + }, + ) + .get("TopicArn") + ) + + response = hook.publish_to_target( + target, message, message_deduplication_id=deduplication_id, message_group_id=group_id + ) + + assert "MessageId" in response diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_sns.py b/providers/amazon/tests/unit/amazon/aws/operators/test_sns.py index 1f7f0315320f0..219501a7fd1fe 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_sns.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_sns.py @@ -27,7 +27,7 @@ TASK_ID = "sns_publish_job" AWS_CONN_ID = "custom_aws_conn" -TARGET_ARN = "arn:aws:sns:eu-central-1:1234567890:test-topic" +TARGET_ARN = "test-topic.fifo" MESSAGE = "Message to send" SUBJECT = "Subject to send" MESSAGE_ATTRIBUTES = {"test-attribute": "Attribute to send"} @@ -57,6 +57,8 @@ def test_init(self): region_name="us-west-1", verify="/spam/egg.pem", botocore_config={"read_timeout": 42}, + message_deduplication_id="abc", + message_group_id="a", ) assert op.hook.aws_conn_id == AWS_CONN_ID assert op.hook._region_name == "us-west-1" @@ -65,11 +67,24 @@ def test_init(self): assert op.hook._config.read_timeout == 42 @mock.patch.object(SnsPublishOperator, "hook") - def test_execute(self, mocked_hook): + @pytest.mark.parametrize( + "message_deduplication_id_,message_group_id_", + [ + ("abc", "a"), + (None, None), + ("abc", None), + (None, "a"), + ], + ) + def test_execute(self, mocked_hook, message_deduplication_id_, message_group_id_): hook_response = {"MessageId": "foobar"} mocked_hook.publish_to_target.return_value = hook_response - op = SnsPublishOperator(**self.default_op_kwargs) + op = SnsPublishOperator( + **self.default_op_kwargs, + message_deduplication_id=message_deduplication_id_, + message_group_id=message_group_id_, + ) assert op.execute({}) == hook_response mocked_hook.publish_to_target.assert_called_once_with( @@ -77,8 +92,12 @@ def test_execute(self, mocked_hook): message_attributes=MESSAGE_ATTRIBUTES, subject=SUBJECT, target_arn=TARGET_ARN, + message_deduplication_id=message_deduplication_id_, + message_group_id=message_group_id_, ) def test_template_fields(self): - operator = SnsPublishOperator(**self.default_op_kwargs) + operator = SnsPublishOperator( + **self.default_op_kwargs, message_deduplication_id="abc", message_group_id="a" + ) validate_template_fields(operator)