From b280c3979dc6a81ecffedace9e3b5efd81103870 Mon Sep 17 00:00:00 2001 From: pratiksha rajendrabhai badheka Date: Wed, 18 Dec 2024 23:11:05 +0530 Subject: [PATCH 1/2] add MessageDeduplicationId support to AWS SqsPublishOperator --- .../airflow/providers/amazon/aws/hooks/sqs.py | 4 ++++ .../providers/amazon/aws/operators/sqs.py | 6 ++++++ .../tests/amazon/aws/operators/test_sqs.py | 19 ++++++++++++++++++- 3 files changed, 28 insertions(+), 1 deletion(-) diff --git a/providers/src/airflow/providers/amazon/aws/hooks/sqs.py b/providers/src/airflow/providers/amazon/aws/hooks/sqs.py index f384bd4d28f24..c8793104e4aa3 100644 --- a/providers/src/airflow/providers/amazon/aws/hooks/sqs.py +++ b/providers/src/airflow/providers/amazon/aws/hooks/sqs.py @@ -59,6 +59,7 @@ def send_message( delay_seconds: int = 0, message_attributes: dict | None = None, message_group_id: str | None = None, + message_deduplication_id: str | None = None, ) -> dict: """ Send message to the queue. @@ -71,6 +72,7 @@ def send_message( :param delay_seconds: seconds to delay the message :param message_attributes: additional attributes for the message (default: None) :param message_group_id: This applies only to FIFO (first-in-first-out) queues. (default: None) + :param message_deduplication_id: This applies only to FIFO (first-in-first-out) queues. :return: dict with the information about the message sent """ params = { @@ -81,5 +83,7 @@ def send_message( } if message_group_id: params["MessageGroupId"] = message_group_id + if message_deduplication_id: + params["MessageDeduplicationId"] = message_deduplication_id return self.get_conn().send_message(**params) diff --git a/providers/src/airflow/providers/amazon/aws/operators/sqs.py b/providers/src/airflow/providers/amazon/aws/operators/sqs.py index dc453d0e6d1e4..817e611327034 100644 --- a/providers/src/airflow/providers/amazon/aws/operators/sqs.py +++ b/providers/src/airflow/providers/amazon/aws/operators/sqs.py @@ -44,6 +44,8 @@ class SqsPublishOperator(AwsBaseOperator[SqsHook]): :param delay_seconds: message delay (templated) (default: 1 second) :param message_group_id: This parameter applies only to FIFO (first-in-first-out) queues. (default: None) For details of the attributes parameter see :py:meth:`botocore.client.SQS.send_message` + :param message_deduplication_id: This applies only to FIFO (first-in-first-out) queues. + For details of the attributes parameter see :py:meth:`botocore.client.SQS.send_message` :param aws_conn_id: The Airflow connection used for AWS credentials. If this is ``None`` or empty then the default boto3 behaviour is used. If running Airflow in a distributed manner and aws_conn_id is None or @@ -63,6 +65,7 @@ class SqsPublishOperator(AwsBaseOperator[SqsHook]): "delay_seconds", "message_attributes", "message_group_id", + "message_deduplication_id", ) template_fields_renderers = {"message_attributes": "json"} ui_color = "#6ad3fa" @@ -75,6 +78,7 @@ def __init__( message_attributes: dict | None = None, delay_seconds: int = 0, message_group_id: str | None = None, + message_deduplication_id: str | None = None, **kwargs, ): super().__init__(**kwargs) @@ -83,6 +87,7 @@ def __init__( self.delay_seconds = delay_seconds self.message_attributes = message_attributes or {} self.message_group_id = message_group_id + self.message_deduplication_id = message_deduplication_id def execute(self, context: Context) -> dict: """ @@ -98,6 +103,7 @@ def execute(self, context: Context) -> dict: delay_seconds=self.delay_seconds, message_attributes=self.message_attributes, message_group_id=self.message_group_id, + message_deduplication_id=self.message_deduplication_id, ) self.log.info("send_message result: %s", result) diff --git a/providers/tests/amazon/aws/operators/test_sqs.py b/providers/tests/amazon/aws/operators/test_sqs.py index 21231bfbb2b2b..20293fe801ec8 100644 --- a/providers/tests/amazon/aws/operators/test_sqs.py +++ b/providers/tests/amazon/aws/operators/test_sqs.py @@ -103,6 +103,20 @@ def test_execute_failure_fifo_queue(self, mocked_context): with pytest.raises(ClientError, match=error_message): op.execute(mocked_context) + @mock_aws + def test_deduplication_failure(self, mocked_context): + self.sqs_client.create_queue( + QueueName=FIFO_QUEUE_NAME, Attributes={"FifoQueue": "true", "ContentBasedDeduplication": "false"} + ) + + op = SqsPublishOperator(**self.default_op_kwargs, sqs_queue=FIFO_QUEUE_NAME, message_group_id="abc") + error_message = ( + r"An error occurred \(InvalidParameterValue\) when calling the SendMessage operation: " + r"The queue should either have ContentBasedDeduplication enabled or MessageDeduplicationId provided explicitly" + ) + with pytest.raises(ClientError, match=error_message): + op.execute(mocked_context) + @mock_aws def test_execute_success_fifo_queue(self, mocked_context): self.sqs_client.create_queue( @@ -124,6 +138,9 @@ def test_execute_success_fifo_queue(self, mocked_context): def test_template_fields(self): operator = SqsPublishOperator( - **self.default_op_kwargs, sqs_queue=FIFO_QUEUE_NAME, message_group_id="abc" + **self.default_op_kwargs, + sqs_queue=FIFO_QUEUE_NAME, + message_group_id="abc", + message_deduplication_id="abc", ) validate_template_fields(operator) From 3b5918b5399dfdf7ae008bd21365931fcb7e2c35 Mon Sep 17 00:00:00 2001 From: pratiksha rajendrabhai badheka Date: Sat, 21 Dec 2024 23:19:02 +0530 Subject: [PATCH 2/2] modified test --- providers/tests/amazon/aws/operators/test_sqs.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/providers/tests/amazon/aws/operators/test_sqs.py b/providers/tests/amazon/aws/operators/test_sqs.py index 20293fe801ec8..6c39f5fef72ee 100644 --- a/providers/tests/amazon/aws/operators/test_sqs.py +++ b/providers/tests/amazon/aws/operators/test_sqs.py @@ -124,17 +124,25 @@ def test_execute_success_fifo_queue(self, mocked_context): ) # Send SQS Message into the FIFO Queue - op = SqsPublishOperator(**self.default_op_kwargs, sqs_queue=FIFO_QUEUE_NAME, message_group_id="abc") + op = SqsPublishOperator( + **self.default_op_kwargs, + sqs_queue=FIFO_QUEUE_NAME, + message_group_id="abc", + message_deduplication_id="abc", + ) result = op.execute(mocked_context) assert "MD5OfMessageBody" in result assert "MessageId" in result # Validate message through moto - message = self.sqs_client.receive_message(QueueUrl=FIFO_QUEUE_URL, AttributeNames=["MessageGroupId"]) + message = self.sqs_client.receive_message( + QueueUrl=FIFO_QUEUE_URL, AttributeNames=["MessageGroupId", "MessageDeduplicationId"] + ) assert len(message["Messages"]) == 1 assert message["Messages"][0]["MessageId"] == result["MessageId"] assert message["Messages"][0]["Body"] == "hello" assert message["Messages"][0]["Attributes"]["MessageGroupId"] == "abc" + assert message["Messages"][0]["Attributes"]["MessageDeduplicationId"] == "abc" def test_template_fields(self): operator = SqsPublishOperator(