Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve AWS SQS Sensor (#16880) #16904

Merged
merged 6 commits into from
Aug 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 113 additions & 22 deletions airflow/providers/amazon/aws/sensors/sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
# specific language governing permissions and limitations
# under the License.
"""Reads and then deletes the message from SQS queue"""
from typing import Optional
import json
from typing import Any, Optional

from jsonpath_ng import parse
from typing_extensions import Literal

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sqs import SQSHook
Expand All @@ -37,9 +41,26 @@ class SQSSensor(BaseSensorOperator):
:type max_messages: int
:param wait_time_seconds: The time in seconds to wait for receiving messages (default: 1 second)
:type wait_time_seconds: int
:param visibility_timeout: Visibility timeout, a period of time during which
Amazon SQS prevents other consumers from receiving and processing the message.
:type visibility_timeout: Optional[int]
:param message_filtering: Specifies how received messages should be filtered. Supported options are:
`None` (no filtering, default), `'literal'` (message Body literal match) or `'jsonpath'`
(message Body filtered using a JSONPath expression).
You may add further methods by overriding the relevant class methods.
:type message_filtering: Optional[Literal["literal", "jsonpath"]]
:param message_filtering_match_values: Optional value/s for the message filter to match on.
For example, with literal matching, if a message body matches any of the specified values
then it is included. For JSONPath matching, the result of the JSONPath expression is used
and may match any of the specified values.
:type message_filtering_match_values: Any
:param message_filtering_config: Additional configuration to pass to the message filter.
For example with JSONPath filtering you can pass a JSONPath expression string here,
such as `'foo[*].baz'`. Messages with a Body which does not match are ignored.
:type message_filtering_config: Any
"""

template_fields = ('sqs_queue', 'max_messages')
template_fields = ('sqs_queue', 'max_messages', 'message_filtering_config')

def __init__(
self,
Expand All @@ -48,13 +69,32 @@ def __init__(
aws_conn_id: str = 'aws_default',
max_messages: int = 5,
wait_time_seconds: int = 1,
visibility_timeout: Optional[int] = None,
message_filtering: Optional[Literal["literal", "jsonpath"]] = None,
message_filtering_match_values: Any = None,
message_filtering_config: Any = None,
**kwargs,
):
super().__init__(**kwargs)
self.sqs_queue = sqs_queue
self.aws_conn_id = aws_conn_id
self.max_messages = max_messages
self.wait_time_seconds = wait_time_seconds
self.visibility_timeout = visibility_timeout

self.message_filtering = message_filtering

if message_filtering_match_values is not None:
if not isinstance(message_filtering_match_values, set):
message_filtering_match_values = set(message_filtering_match_values)
self.message_filtering_match_values = message_filtering_match_values
uranusjr marked this conversation as resolved.
Show resolved Hide resolved

if self.message_filtering == 'literal':
if self.message_filtering_match_values is None:
raise TypeError('message_filtering_match_values must be specified for literal matching')

self.message_filtering_config = message_filtering_config

self.hook: Optional[SQSHook] = None

def poke(self, context):
Expand All @@ -69,31 +109,48 @@ def poke(self, context):

self.log.info('SQSSensor checking for message on queue: %s', self.sqs_queue)

messages = sqs_conn.receive_message(
QueueUrl=self.sqs_queue,
MaxNumberOfMessages=self.max_messages,
WaitTimeSeconds=self.wait_time_seconds,
)
receive_message_kwargs = {
'QueueUrl': self.sqs_queue,
'MaxNumberOfMessages': self.max_messages,
'WaitTimeSeconds': self.wait_time_seconds,
}
if self.visibility_timeout is not None:
receive_message_kwargs['VisibilityTimeout'] = self.visibility_timeout

self.log.info("received message %s", str(messages))
response = sqs_conn.receive_message(**receive_message_kwargs)

if 'Messages' in messages and messages['Messages']:
entries = [
{'Id': message['MessageId'], 'ReceiptHandle': message['ReceiptHandle']}
for message in messages['Messages']
]
if "Messages" not in response:
return False

result = sqs_conn.delete_message_batch(QueueUrl=self.sqs_queue, Entries=entries)
messages = response['Messages']
num_messages = len(messages)
self.log.info("Received %d messages", num_messages)

if 'Successful' in result:
context['ti'].xcom_push(key='messages', value=messages)
return True
else:
raise AirflowException(
'Delete SQS Messages failed ' + str(result) + ' for messages ' + str(messages)
)
if not num_messages:
return False

return False
if self.message_filtering:
messages = self.filter_messages(messages)
num_messages = len(messages)
self.log.info("There are %d messages left after filtering", num_messages)

if not num_messages:
return False

self.log.info("Deleting %d messages", num_messages)

entries = [
{'Id': message['MessageId'], 'ReceiptHandle': message['ReceiptHandle']} for message in messages
]
response = sqs_conn.delete_message_batch(QueueUrl=self.sqs_queue, Entries=entries)

if 'Successful' in response:
context['ti'].xcom_push(key='messages', value=messages)
return True
else:
raise AirflowException(
'Delete SQS Messages failed ' + str(response) + ' for messages ' + str(messages)
)

def get_hook(self) -> SQSHook:
"""Create and return an SQSHook"""
Expand All @@ -102,3 +159,37 @@ def get_hook(self) -> SQSHook:

self.hook = SQSHook(aws_conn_id=self.aws_conn_id)
return self.hook

def filter_messages(self, messages):
    """Route ``messages`` to the filter selected by ``self.message_filtering``.

    :param messages: received messages, handed unchanged to the selected filter method
    :return: the result of ``filter_messages_literal`` for ``'literal'`` or of
        ``filter_messages_jsonpath`` for ``'jsonpath'``
    :raises NotImplementedError: for any other ``self.message_filtering`` value;
        subclasses can override this method to support additional filters
    """
    selected = self.message_filtering
    if selected == 'literal':
        handler = self.filter_messages_literal
    elif selected == 'jsonpath':
        handler = self.filter_messages_jsonpath
    else:
        raise NotImplementedError('Override this method to define custom filters')
    return handler(messages)

def filter_messages_literal(self, messages):
    """Keep the messages whose ``'Body'`` is one of the configured match values.

    :param messages: iterable of message dicts, each providing a ``'Body'`` entry
    :return: list of the messages whose ``'Body'`` is contained in
        ``self.message_filtering_match_values``, in their original order
    """
    return [
        candidate
        for candidate in messages
        if candidate['Body'] in self.message_filtering_match_values
    ]

def filter_messages_jsonpath(self, messages):
    """Keep the messages whose JSON ``'Body'`` matches the configured JSONPath expression.

    The expression string is read from ``self.message_filtering_config``. Each message
    ``'Body'`` is decoded with ``json.loads`` and searched with the parsed expression:

    * a message producing no JSONPath results is dropped;
    * when ``self.message_filtering_match_values`` is ``None``, any result keeps the message;
    * otherwise the message is kept if at least one result value is contained in
      ``self.message_filtering_match_values``.

    :param messages: iterable of message dicts, each with a JSON string under ``'Body'``
    :return: list of the messages that passed the filter, in their original order
    :raises json.JSONDecodeError: if a message ``'Body'`` is not valid JSON
        (propagated from ``json.loads``)
    """
    jsonpath_expr = parse(self.message_filtering_config)
    match_values = self.message_filtering_match_values

    def is_wanted(value):
        # A JSONPath result can be a list or dict. Testing an unhashable value for
        # membership in a set raises TypeError; such a value cannot be in the set,
        # so report "no match" instead of failing the whole poke.
        try:
            return value in match_values
        except TypeError:
            return False

    filtered_messages = []
    for message in messages:
        # Body is a string: deserialise it to an object, then search that object
        results = jsonpath_expr.find(json.loads(message['Body']))
        if not results:
            continue
        if match_values is None or any(is_wanted(result.value) for result in results):
            filtered_messages.append(message)
    return filtered_messages
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version
amazon = [
'boto3>=1.15.0,<1.18.0',
'watchtower~=1.0.6',
'jsonpath_ng>=1.5.3',
]
apache_beam = [
'apache-beam>=2.20.0',
Expand Down
178 changes: 178 additions & 0 deletions tests/providers/amazon/aws/sensors/test_sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.


import json
import unittest
from unittest import mock

Expand Down Expand Up @@ -107,3 +108,180 @@ def test_poke_receive_raise_exception(self, mock_conn):
self.sensor.poke(self.mock_context)

assert 'test exception' in ctx.value.args[0]

@mock.patch.object(SQSHook, 'get_conn')
def test_poke_visibility_timeout(self, mock_conn):
    """``receive_message`` gets ``VisibilityTimeout`` when the sensor is built with one."""
    self.sqs_hook.create_queue('test')
    self.sqs_hook.send_message(queue_url='test', message_body='hello')

    # Keyword arguments expected on every receive_message call in this test.
    base_kwargs = {'QueueUrl': 'test', 'MaxNumberOfMessages': 5, 'WaitTimeSeconds': 1}

    # Existing sensor, no visibility_timeout: the call carries only the base kwargs.
    self.sensor.poke(self.mock_context)
    mock_conn.assert_has_calls([mock.call().receive_message(**base_kwargs)])

    # Sensor created with visibility_timeout=42: the value is forwarded as VisibilityTimeout.
    self.sensor = SQSSensor(
        task_id='test_task2',
        dag=self.dag,
        sqs_queue='test',
        aws_conn_id='aws_default',
        visibility_timeout=42,
    )
    self.sensor.poke(self.mock_context)
    mock_conn.assert_has_calls([mock.call().receive_message(VisibilityTimeout=42, **base_kwargs)])

@mock_sqs
def test_poke_message_invalid_filtering(self):
    """Poking with an unsupported ``message_filtering`` option raises ``NotImplementedError``."""
    self.sqs_hook.create_queue('test')
    self.sqs_hook.send_message(queue_url='test', message_body='hello')

    sensor_kwargs = {
        'task_id': 'test_task2',
        'dag': self.dag,
        'sqs_queue': 'test',
        'aws_conn_id': 'aws_default',
        'message_filtering': 'invalid_option',
    }
    bad_sensor = SQSSensor(**sensor_kwargs)

    with pytest.raises(NotImplementedError) as excinfo:
        bad_sensor.poke(self.mock_context)
    # The message points users at the extension hook for custom filters.
    assert 'Override this method to define custom filters' in excinfo.value.args[0]

@mock.patch.object(SQSHook, "get_conn")
def test_poke_message_filtering_literal_values(self, mock_conn):
    """With ``'literal'`` filtering, only messages whose body matches are deleted."""
    self.sqs_hook.create_queue('test')
    matching = [{"id": 11, "body": "a matching message"}]
    non_matching = [{"id": 12, "body": "a non-matching message"}]
    # Named ``all_messages`` so the builtin ``all`` is not shadowed.
    all_messages = matching + non_matching

    def mock_receive_message(**kwargs):
        # Fake receive_message response containing every prepared message.
        return {
            'Messages': [
                {
                    'MessageId': message['id'],
                    'ReceiptHandle': 100 + message['id'],
                    'Body': message['body'],
                }
                for message in all_messages
            ]
        }

    mock_conn.return_value.receive_message.side_effect = mock_receive_message

    def mock_delete_message_batch(**kwargs):
        # Dict-shaped fake response; the sensor only checks for the 'Successful' key.
        return {'Successful': [{'Id': entry['Id']} for entry in kwargs['Entries']]}

    mock_conn.return_value.delete_message_batch.side_effect = mock_delete_message_batch

    # Test that messages are filtered
    self.sensor.message_filtering = 'literal'
    self.sensor.message_filtering_match_values = ["a matching message"]
    result = self.sensor.poke(self.mock_context)
    assert result

    # Test that only filtered messages are deleted
    delete_entries = [{'Id': x['id'], 'ReceiptHandle': 100 + x['id']} for x in matching]
    calls_delete_message_batch = [
        mock.call().delete_message_batch(QueueUrl='test', Entries=delete_entries)
    ]
    mock_conn.assert_has_calls(calls_delete_message_batch)

@mock.patch.object(SQSHook, "get_conn")
def test_poke_message_filtering_jsonpath(self, mock_conn):
    """With ``'jsonpath'`` filtering and no match values, only bodies matching the path are deleted."""
    self.sqs_hook.create_queue('test')
    matching = [
        {"id": 11, "key": {"matches": [1, 2]}},
        {"id": 12, "key": {"matches": [3, 4, 5]}},
        {"id": 13, "key": {"matches": [10]}},
    ]
    non_matching = [
        {"id": 14, "key": {"nope": [5, 6]}},
        {"id": 15, "key": {"nope": [7, 8]}},
    ]
    # Named ``all_messages`` so the builtin ``all`` is not shadowed.
    all_messages = matching + non_matching

    def mock_receive_message(**kwargs):
        # Fake receive_message response; each Body is the JSON dump of the message dict.
        return {
            'Messages': [
                {
                    'MessageId': message['id'],
                    'ReceiptHandle': 100 + message['id'],
                    'Body': json.dumps(message),
                }
                for message in all_messages
            ]
        }

    mock_conn.return_value.receive_message.side_effect = mock_receive_message

    def mock_delete_message_batch(**kwargs):
        # Dict-shaped fake response; the sensor only checks for the 'Successful' key.
        return {'Successful': [{'Id': entry['Id']} for entry in kwargs['Entries']]}

    mock_conn.return_value.delete_message_batch.side_effect = mock_delete_message_batch

    # Test that messages are filtered
    self.sensor.message_filtering = 'jsonpath'
    self.sensor.message_filtering_config = 'key.matches[*]'
    result = self.sensor.poke(self.mock_context)
    assert result

    # Test that only filtered messages are deleted
    delete_entries = [{'Id': x['id'], 'ReceiptHandle': 100 + x['id']} for x in matching]
    calls_delete_message_batch = [
        mock.call().delete_message_batch(QueueUrl='test', Entries=delete_entries)
    ]
    mock_conn.assert_has_calls(calls_delete_message_batch)

@mock.patch.object(SQSHook, "get_conn")
def test_poke_message_filtering_jsonpath_values(self, mock_conn):
    """With ``'jsonpath'`` filtering and match values, only messages with a matching result value are deleted."""
    self.sqs_hook.create_queue('test')
    matching = [
        {"id": 11, "key": {"matches": [1, 2]}},
        {"id": 12, "key": {"matches": [1, 4, 5]}},
        {"id": 13, "key": {"matches": [4, 5]}},
    ]
    non_matching = [
        # The path matches here, but none of the values is in the match values.
        {"id": 21, "key": {"matches": [10]}},
        {"id": 22, "key": {"nope": [5, 6]}},
        {"id": 23, "key": {"nope": [7, 8]}},
    ]
    # Named ``all_messages`` so the builtin ``all`` is not shadowed.
    all_messages = matching + non_matching

    def mock_receive_message(**kwargs):
        # Fake receive_message response; each Body is the JSON dump of the message dict.
        return {
            'Messages': [
                {
                    'MessageId': message['id'],
                    'ReceiptHandle': 100 + message['id'],
                    'Body': json.dumps(message),
                }
                for message in all_messages
            ]
        }

    mock_conn.return_value.receive_message.side_effect = mock_receive_message

    def mock_delete_message_batch(**kwargs):
        # Dict-shaped fake response; the sensor only checks for the 'Successful' key.
        return {'Successful': [{'Id': entry['Id']} for entry in kwargs['Entries']]}

    mock_conn.return_value.delete_message_batch.side_effect = mock_delete_message_batch

    # Test that messages are filtered
    self.sensor.message_filtering = 'jsonpath'
    self.sensor.message_filtering_config = 'key.matches[*]'
    self.sensor.message_filtering_match_values = [1, 4]
    result = self.sensor.poke(self.mock_context)
    assert result

    # Test that only filtered messages are deleted
    delete_entries = [{'Id': x['id'], 'ReceiptHandle': 100 + x['id']} for x in matching]
    calls_delete_message_batch = [
        mock.call().delete_message_batch(QueueUrl='test', Entries=delete_entries)
    ]
    mock_conn.assert_has_calls(calls_delete_message_batch)