diff --git a/providers/redis/docs/index.rst b/providers/redis/docs/index.rst index 63b7745fcd5cd..e195e6bfa9992 100644 --- a/providers/redis/docs/index.rst +++ b/providers/redis/docs/index.rst @@ -36,6 +36,7 @@ Connection types Logging + Triggers .. toctree:: :hidden: diff --git a/providers/redis/docs/triggers.rst b/providers/redis/docs/triggers.rst new file mode 100644 index 0000000000000..96274b737a965 --- /dev/null +++ b/providers/redis/docs/triggers.rst @@ -0,0 +1,29 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Redis Triggers +============== + +.. _howto/triggers:AwaitMessageTrigger: + +AwaitMessageTrigger +------------------- + +The ``AwaitMessageTrigger`` is a trigger that asynchronously waits for a message to arrive on one or more specified Redis PubSub channels. + +For parameter definitions take a look at :class:`~airflow.providers.redis.triggers.redis_await_message.AwaitMessageTrigger`. diff --git a/providers/redis/provider.yaml b/providers/redis/provider.yaml index b509a41492b77..c22ce2e2aac32 100644 --- a/providers/redis/provider.yaml +++ b/providers/redis/provider.yaml @@ -73,6 +73,11 @@ sensors: - airflow.providers.redis.sensors.redis_key - airflow.providers.redis.sensors.redis_pub_sub +triggers: + - integration-name: Redis + python-modules: + - airflow.providers.redis.triggers.redis_await_message + hooks: - integration-name: Redis python-modules: diff --git a/providers/redis/src/airflow/providers/redis/get_provider_info.py b/providers/redis/src/airflow/providers/redis/get_provider_info.py index 1a4dfa8f5acf7..8ee4dd23637bc 100644 --- a/providers/redis/src/airflow/providers/redis/get_provider_info.py +++ b/providers/redis/src/airflow/providers/redis/get_provider_info.py @@ -49,6 +49,12 @@ def get_provider_info(): ], } ], + "triggers": [ + { + "integration-name": "Redis", + "python-modules": ["airflow.providers.redis.triggers.redis_await_message"], + } + ], "hooks": [{"integration-name": "Redis", "python-modules": ["airflow.providers.redis.hooks.redis"]}], "connection-types": [ {"hook-class-name": "airflow.providers.redis.hooks.redis.RedisHook", "connection-type": "redis"} diff --git a/providers/redis/src/airflow/providers/redis/triggers/__init__.py b/providers/redis/src/airflow/providers/redis/triggers/__init__.py new file mode 100644 index 0000000000000..217e5db960782 --- /dev/null +++ b/providers/redis/src/airflow/providers/redis/triggers/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/redis/src/airflow/providers/redis/triggers/redis_await_message.py b/providers/redis/src/airflow/providers/redis/triggers/redis_await_message.py new file mode 100644 index 0000000000000..8292eb1b84891 --- /dev/null +++ b/providers/redis/src/airflow/providers/redis/triggers/redis_await_message.py @@ -0,0 +1,72 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +from typing import Any + +from asgiref.sync import sync_to_async + +from airflow.providers.redis.hooks.redis import RedisHook +from airflow.triggers.base import BaseTrigger, TriggerEvent + + +class AwaitMessageTrigger(BaseTrigger): + """ + A trigger that waits for a message matching specific criteria to arrive in Redis. + + The behavior of this trigger is as follows: + - poll the Redis pubsub for a message, if no message returned, sleep + + :param channels: The channels that should be searched for messages + :param redis_conn_id: The connection object to use, defaults to "redis_default" + :param poll_interval: How long the trigger should sleep after reaching the end of the Redis log + (seconds), defaults to 60 + """ + + def __init__( + self, + channels: list[str] | str, + redis_conn_id: str = "redis_default", + poll_interval: float = 60, + ) -> None: + self.channels = channels + self.redis_conn_id = redis_conn_id + self.poll_interval = poll_interval + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + "airflow.providers.redis.triggers.redis_await_message.AwaitMessageTrigger", + { + "channels": self.channels, + "redis_conn_id": self.redis_conn_id, + "poll_interval": self.poll_interval, + }, + ) + + async def run(self): + hook = RedisHook(redis_conn_id=self.redis_conn_id).get_conn().pubsub() + hook.subscribe(self.channels) + + async_get_message = sync_to_async(hook.get_message) + while True: + message = await async_get_message() + + if message and message["type"] == "message": + yield TriggerEvent(message) + break + await asyncio.sleep(self.poll_interval) diff --git a/providers/redis/tests/unit/redis/triggers/__init__.py b/providers/redis/tests/unit/redis/triggers/__init__.py new file mode 100644 index 0000000000000..217e5db960782 --- /dev/null +++ b/providers/redis/tests/unit/redis/triggers/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/redis/tests/unit/redis/triggers/test_redis_await_message.py b/providers/redis/tests/unit/redis/triggers/test_redis_await_message.py new file mode 100644 index 0000000000000..193fdb107fed6 --- /dev/null +++ b/providers/redis/tests/unit/redis/triggers/test_redis_await_message.py @@ -0,0 +1,90 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +from unittest.mock import patch + +import pytest + +from airflow.providers.redis.triggers.redis_await_message import AwaitMessageTrigger + + +class TestAwaitMessageTrigger: + def test_trigger_serialization(self): + trigger = AwaitMessageTrigger( + channels=["test_channel"], + redis_conn_id="redis_default", + poll_interval=30, + ) + + assert isinstance(trigger, AwaitMessageTrigger) + + classpath, kwargs = trigger.serialize() + + assert classpath == "airflow.providers.redis.triggers.redis_await_message.AwaitMessageTrigger" + assert kwargs == dict( + channels=["test_channel"], + redis_conn_id="redis_default", + poll_interval=30, + ) + + @patch("airflow.providers.redis.hooks.redis.RedisHook.get_conn") + @pytest.mark.asyncio + async def test_trigger_run_succeed(self, mock_redis_conn): + trigger = AwaitMessageTrigger( + channels="test", + redis_conn_id="redis_default", + poll_interval=0.0001, + ) + + mock_redis_conn().pubsub().get_message.return_value = { + "type": "message", + "channel": "test", + "data": "d1", + } + + trigger_gen = trigger.run() + task = asyncio.create_task(trigger_gen.__anext__()) + event = await task + assert task.done() is True + assert event.payload["data"] == "d1" + assert event.payload["channel"] == "test" + asyncio.get_event_loop().stop() + + @patch("airflow.providers.redis.hooks.redis.RedisHook.get_conn") + @pytest.mark.asyncio + async def test_trigger_run_fail(self, mock_redis_conn): + trigger = AwaitMessageTrigger( + channels="test", + redis_conn_id="redis_default", + poll_interval=0.01, + ) + + mock_redis_conn().pubsub().get_message.return_value = { + "type": "subscribe", + "channel": "test", + "data": "d1", + } + + trigger_gen = trigger.run() + task = asyncio.create_task(trigger_gen.__anext__()) + await asyncio.sleep(1.0) + assert task.done() is False + task.cancel() + asyncio.get_event_loop().stop()