diff --git a/providers/http/docs/index.rst b/providers/http/docs/index.rst
index 154bae2a411e0..c571cd538f49a 100644
--- a/providers/http/docs/index.rst
+++ b/providers/http/docs/index.rst
@@ -36,6 +36,7 @@ Connection types

     Operators
+    Triggers

 .. toctree::
     :hidden:
diff --git a/providers/http/docs/triggers.rst b/providers/http/docs/triggers.rst
new file mode 100644
index 0000000000000..752e2cd957161
--- /dev/null
+++ b/providers/http/docs/triggers.rst
@@ -0,0 +1,143 @@
+
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+HTTP Event Trigger
+==================
+
+.. _howto/trigger:HttpEventTrigger:
+
+The ``HttpEventTrigger`` is an event-based trigger that monitors whether responses
+from an API meet the conditions set by the user in the ``response_check`` callable.
+
+It is designed for **Airflow 3.0+** to be used in combination with the ``AssetWatcher`` system,
+enabling event-driven DAGs based on API responses.
+
+How It Works
+------------
+
+1. Sends requests to an API.
+2. Uses the callable at ``response_check_path`` to evaluate the API response.
+3. If the callable returns ``True``, a ``TriggerEvent`` is emitted. This will trigger DAGs using this ``AssetWatcher`` for scheduling.
+
+.. note::
+    This trigger requires **Airflow >= 3.0** due to dependencies on ``AssetWatcher`` and event-driven scheduling infrastructure.
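+
+When the check passes, the emitted ``TriggerEvent`` payload carries the API response, pickled and
+base64-encoded. The sketch below is illustrative only (``decode_event_response`` is not part of the
+provider) and shows one way such a payload could be turned back into a ``requests.Response``:
+
+.. code-block:: python
+
+    import base64
+    import pickle
+
+
+    def decode_event_response(event_payload: dict):
+        # The trigger stores the response as base64(pickle(requests.Response)).
+        return pickle.loads(base64.standard_b64decode(event_payload["response"]))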
+
+Usage Example with AssetWatcher
+-------------------------------
+
+Here's an example of using the ``HttpEventTrigger`` in an ``AssetWatcher`` to monitor the GitHub API for new Airflow releases.
+
+.. code-block:: python
+
+    import datetime
+    import os
+
+    from asgiref.sync import sync_to_async
+
+    from airflow.providers.http.triggers.http import HttpEventTrigger
+    from airflow.sdk import Asset, AssetWatcher, Variable, dag, task
+
+    # This token must be generated through GitHub and added as an environment variable
+    token = os.getenv("GITHUB_TOKEN")
+
+    headers = {
+        "Accept": "application/vnd.github+json",
+        "Authorization": f"Bearer {token}",
+        "X-GitHub-Api-Version": "2022-11-28",
+    }
+
+
+    async def check_github_api_response(response):
+        data = response.json()
+        release_id = str(data["id"])
+        get_variable_sync = sync_to_async(Variable.get)
+        previous_release_id = await get_variable_sync(key="release_id_var", default=None)
+        if release_id == previous_release_id:
+            return False
+        release_name = data["name"]
+        release_html_url = data["html_url"]
+        set_variable_sync = sync_to_async(Variable.set)
+        await set_variable_sync(key="release_id_var", value=str(release_id))
+        await set_variable_sync(key="release_name_var", value=release_name)
+        await set_variable_sync(key="release_html_url_var", value=release_html_url)
+        return True
+
+
+    trigger = HttpEventTrigger(
+        endpoint="repos/apache/airflow/releases/latest",
+        method="GET",
+        http_conn_id="http_default",  # HTTP connection with https://api.github.com/ as the Host
+        headers=headers,
+        response_check_path="dags.check_airflow_releases.check_github_api_response",  # Path to the check_github_api_response callable
+    )
+
+    asset = Asset(
+        "airflow_releases_asset", watchers=[AssetWatcher(name="airflow_releases_watcher", trigger=trigger)]
+    )
+
+
+    @dag(start_date=datetime.datetime(2024, 10, 1), schedule=asset, catchup=False)
+    def check_airflow_releases():
+        @task()
+        def print_airflow_release_info():
+            release_name = Variable.get("release_name_var")
+            release_html_url = Variable.get("release_html_url_var")
+            print(f"{release_name} has been released. Check it out at {release_html_url}")
+
+        print_airflow_release_info()
+
+
+    check_airflow_releases()
+
+Parameters
+----------
+
+``http_conn_id``
+    HTTP connection ID that has the base API URL, e.g. ``https://www.google.com/``, and optional authentication credentials.
+    Default headers can also be specified in the Extra field in JSON format.
+
+``auth_type``
+    The auth type for the service.
+
+``method``
+    The API method to be called.
+
+``endpoint``
+    Endpoint to be called, e.g. ``resource/v1/query?``.
+
+``headers``
+    Additional headers to be passed through as a dict.
+
+``data``
+    Payload to be uploaded or request parameters.
+
+``extra_options``
+    Additional kwargs to pass when creating a request.
+
+``response_check_path``
+    Path to a callable that evaluates whether the API response passes the conditions set by the user to trigger DAGs.
+
+
+Important Notes
+---------------
+
+1. A ``response_check_path`` value is required.
+2. The ``response_check_path`` must contain the path to an asynchronous callable. Synchronous callables will raise an exception.
+3. This trigger does not automatically record the previous API response.
+4. The previous response may have to be persisted manually through ``Variable.set()`` in the ``response_check_path`` callable to prevent the trigger from emitting events repeatedly for the same API response.
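+
+As a reference, a minimal ``response_check_path`` callable could look like the sketch below
+(``dags/check_status.py`` is a hypothetical module; the only requirements are that the callable
+is asynchronous and returns ``True`` whenever a ``TriggerEvent`` should be emitted):
+
+.. code-block:: python
+
+    # dags/check_status.py -- hypothetical module referenced as
+    # response_check_path="dags.check_status.simple_response_check"
+    async def simple_response_check(response):
+        # `response` is a requests.Response converted from the aiohttp response
+        return response.status_code == 200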
diff --git a/providers/http/src/airflow/providers/http/triggers/http.py b/providers/http/src/airflow/providers/http/triggers/http.py
index 3708e5e13ed69..bd29ac71f596f 100644
--- a/providers/http/src/airflow/providers/http/triggers/http.py
+++ b/providers/http/src/airflow/providers/http/triggers/http.py
@@ -18,20 +18,30 @@

 import asyncio
 import base64
+import importlib
+import inspect
 import pickle
+import sys
 from collections.abc import AsyncIterator
 from importlib import import_module
 from typing import TYPE_CHECKING, Any

 import aiohttp
 import requests
+from asgiref.sync import sync_to_async
 from requests.cookies import RequestsCookieJar
 from requests.structures import CaseInsensitiveDict

 from airflow.exceptions import AirflowException
 from airflow.providers.http.hooks.http import HttpAsyncHook
+from airflow.providers.http.version_compat import AIRFLOW_V_3_0_PLUS
 from airflow.triggers.base import BaseTrigger, TriggerEvent

+if AIRFLOW_V_3_0_PLUS:
+    from airflow.triggers.base import BaseEventTrigger
+else:
+    from airflow.triggers.base import BaseTrigger as BaseEventTrigger  # type: ignore
+
 if TYPE_CHECKING:
     from aiohttp.client_reqrep import ClientResponse

@@ -105,21 +115,9 @@ def serialize(self) -> tuple[str, dict[str, Any]]:

     async def run(self) -> AsyncIterator[TriggerEvent]:
         """Make a series of asynchronous http calls via a http hook."""
-        hook = HttpAsyncHook(
-            method=self.method,
-            http_conn_id=self.http_conn_id,
-            auth_type=self.auth_type,
-        )
+        hook = self._get_async_hook()
         try:
-            async with aiohttp.ClientSession() as session:
-                client_response = await hook.run(
-                    session=session,
-                    endpoint=self.endpoint,
-                    data=self.data,
-                    headers=self.headers,
-                    extra_options=self.extra_options,
-                )
-                response = await self._convert_response(client_response)
+            response = await self._get_response(hook)
             yield TriggerEvent(
                 {
                     "status": "success",
@@ -129,6 +127,25 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
         except Exception as e:
             yield TriggerEvent({"status": "error", "message": str(e)})

+    def _get_async_hook(self) -> HttpAsyncHook:
+        return HttpAsyncHook(
+            method=self.method,
+            http_conn_id=self.http_conn_id,
+            auth_type=self.auth_type,
+        )
+
+    async def _get_response(self, hook):
+        async with aiohttp.ClientSession() as session:
+            client_response = await hook.run(
+                session=session,
+                endpoint=self.endpoint,
+                data=self.data,
+                headers=self.headers,
+                extra_options=self.extra_options,
+            )
+            response = await self._convert_response(client_response)
+            return response
+
     @staticmethod
     async def _convert_response(client_response: ClientResponse) -> requests.Response:
         """Convert aiohttp.client_reqrep.ClientResponse to requests.Response."""
@@ -219,3 +236,84 @@ def _get_async_hook(self) -> HttpAsyncHook:
             method=self.method,
             http_conn_id=self.http_conn_id,
         )
+
+
+class HttpEventTrigger(HttpTrigger, BaseEventTrigger):
+    """
+    HttpEventTrigger for event-based DAG scheduling when the API response satisfies the response check.
+
+    :param response_check_path: Path to the function that evaluates whether the API response
+        passes the conditions set by the user to fire the trigger. This callable must be asynchronous.
+    :param http_conn_id: http connection id that has the base
+        API url i.e https://www.google.com/ and optional authentication credentials. Default
+        headers can also be specified in the Extra field in json format.
+    :param auth_type: The auth type for the service
+    :param method: The API method to be called
+    :param endpoint: Endpoint to be called, i.e. ``resource/v1/query?``.
+    :param headers: Additional headers to be passed through as a dict.
+    :param data: Payload to be uploaded or request parameters.
+    :param extra_options: Additional kwargs to pass when creating a request.
+    """
+
+    def __init__(
+        self,
+        response_check_path: str,
+        http_conn_id: str = "http_default",
+        auth_type: Any = None,
+        method: str = "GET",
+        endpoint: str | None = None,
+        headers: dict[str, str] | None = None,
+        data: dict[str, Any] | str | None = None,
+        extra_options: dict[str, Any] | None = None,
+    ):
+        super().__init__(http_conn_id, auth_type, method, endpoint, headers, data, extra_options)
+        self.response_check_path = response_check_path
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serialize HttpEventTrigger arguments and classpath."""
+        return (
+            self.__class__.__module__ + "." + self.__class__.__qualname__,
+            {
+                "http_conn_id": self.http_conn_id,
+                "method": self.method,
+                "auth_type": self.auth_type,
+                "endpoint": self.endpoint,
+                "headers": self.headers,
+                "data": self.data,
+                "extra_options": self.extra_options,
+                "response_check_path": self.response_check_path,
+            },
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """Make asynchronous http calls via a http hook until the response passes the response check."""
+        hook = super()._get_async_hook()
+        try:
+            while True:
+                response = await super()._get_response(hook)
+                if await self._run_response_check(response):
+                    break
+            yield TriggerEvent(
+                {
+                    "status": "success",
+                    "response": base64.standard_b64encode(pickle.dumps(response)).decode("ascii"),
+                }
+            )
+        except Exception as e:
+            self.log.error("status: error, message: %s", str(e))
+
+    async def _import_from_response_check_path(self):
+        """Import the response check callable from the path provided by the user."""
+        module_path, func_name = self.response_check_path.rsplit(".", 1)
+        if module_path in sys.modules:
+            # Reload so that edits to the user's module are picked up on subsequent checks.
+            await sync_to_async(importlib.reload)(sys.modules[module_path])
+        module = await sync_to_async(importlib.import_module)(module_path)
+        return getattr(module, func_name)
+
+    async def _run_response_check(self, response) -> bool:
+        """Run the response_check callable provided by the user."""
+        response_check = await self._import_from_response_check_path()
+        if not inspect.iscoroutinefunction(response_check):
+            raise AirflowException("The response_check callable is not asynchronous.")
+        check = await response_check(response)
+        return check
diff --git a/providers/http/src/airflow/providers/http/version_compat.py b/providers/http/src/airflow/providers/http/version_compat.py
index 974126e72ed4d..ef9f9d6c244c1 100644
--- a/providers/http/src/airflow/providers/http/version_compat.py
+++ b/providers/http/src/airflow/providers/http/version_compat.py
@@ -33,6 +33,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:

 AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
+
 AIRFLOW_V_3_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 0)

 if AIRFLOW_V_3_1_PLUS:
diff --git a/providers/http/tests/unit/http/triggers/test_http.py b/providers/http/tests/unit/http/triggers/test_http.py
index e2e6a9c5a95c8..8313f69f65449 100644
--- a/providers/http/tests/unit/http/triggers/test_http.py
+++ b/providers/http/tests/unit/http/triggers/test_http.py
@@ -31,7 +31,7 @@
 from yarl import URL

 from airflow.models import Connection
-from airflow.providers.http.triggers.http import HttpSensorTrigger, HttpTrigger
+from airflow.providers.http.triggers.http import HttpEventTrigger, HttpSensorTrigger, HttpTrigger
 from airflow.triggers.base import TriggerEvent

 HTTP_PATH = "airflow.providers.http.triggers.http.{}"
@@ -42,6 +42,7 @@
 TEST_HEADERS = {"Authorization": "Bearer test"}
 TEST_DATA = {"key": "value"}
 TEST_EXTRA_OPTIONS: dict[str, Any] = {}
+TEST_RESPONSE_CHECK_PATH = "mock.path"


 @pytest.fixture
@@ -69,6 +70,20 @@ def sensor_trigger():
     )


+@pytest.fixture
+def event_trigger():
+    return HttpEventTrigger(
+        http_conn_id=TEST_CONN_ID,
+        auth_type=TEST_AUTH_TYPE,
+        method=TEST_METHOD,
+        endpoint=TEST_ENDPOINT,
+        headers=TEST_HEADERS,
+        data=TEST_DATA,
+        extra_options=TEST_EXTRA_OPTIONS,
+        response_check_path=TEST_RESPONSE_CHECK_PATH,
+    )
+
+
 @pytest.fixture
 def client_response():
     client_response = mock.AsyncMock(ClientResponse)
@@ -192,3 +207,99 @@ def test_serialization(self, sensor_trigger):
             "extra_options": TEST_EXTRA_OPTIONS,
             "poke_interval": 5.0,
         }
+
+
+class TestHttpEventTrigger:
+    @staticmethod
+    def _mock_run_result(result_to_mock):
+        f = Future()
+        f.set_result(result_to_mock)
+        return f
+
+    def test_serialization(self, event_trigger):
+        """
+        Asserts that the HttpEventTrigger correctly serializes its arguments
+        and classpath.
+        """
+        classpath, kwargs = event_trigger.serialize()
+        assert classpath == "airflow.providers.http.triggers.http.HttpEventTrigger"
+        assert kwargs == {
+            "http_conn_id": TEST_CONN_ID,
+            "auth_type": TEST_AUTH_TYPE,
+            "method": TEST_METHOD,
+            "endpoint": TEST_ENDPOINT,
+            "headers": TEST_HEADERS,
+            "data": TEST_DATA,
+            "extra_options": TEST_EXTRA_OPTIONS,
+            "response_check_path": TEST_RESPONSE_CHECK_PATH,
+        }
+
+    @pytest.mark.asyncio
+    @mock.patch(HTTP_PATH.format("HttpAsyncHook"))
+    async def test_trigger_on_success_yield_successfully(self, mock_hook, event_trigger, client_response):
+        """
+        Tests that the HttpEventTrigger only fires once the response check passes.
+        """
+        mock_hook.return_value.run.return_value = self._mock_run_result(client_response)
+        event_trigger._run_response_check = mock.AsyncMock(side_effect=[False, True])
+        response = await HttpEventTrigger._convert_response(client_response)
+
+        generator = event_trigger.run()
+        actual = await generator.asend(None)
+        assert actual == TriggerEvent(
+            {
+                "status": "success",
+                "response": base64.standard_b64encode(pickle.dumps(response)).decode("ascii"),
+            }
+        )
+        assert mock_hook.return_value.run.call_count == 2
+        assert event_trigger._run_response_check.call_count == 2
+
+    @pytest.mark.asyncio
+    @mock.patch(HTTP_PATH.format("HttpAsyncHook"))
+    async def test_trigger_on_exception_logs_error_and_never_yields(
+        self, mock_hook, event_trigger, monkeypatch
+    ):
+        """
+        Tests the HttpEventTrigger logs the appropriate message and does not yield a TriggerEvent when an exception is raised.
+        """
+        mock_hook.return_value.run.side_effect = Exception("Test exception")
+        mock_logger = mock.Mock()
+        monkeypatch.setattr(type(event_trigger), "log", mock_logger)
+
+        generator = event_trigger.run()
+        with pytest.raises(StopAsyncIteration):
+            await generator.asend(None)
+
+        mock_logger.error.assert_called_once_with("status: error, message: %s", "Test exception")
+
+    @pytest.mark.asyncio
+    async def test_convert_response(self, client_response):
+        """
+        Assert that aiohttp.client_reqrep.ClientResponse is converted to requests.Response.
+ """ + response = await HttpEventTrigger._convert_response(client_response) + assert response.content == await client_response.read() + assert response.status_code == client_response.status + assert response.headers == CaseInsensitiveDict(client_response.headers) + assert response.url == str(client_response.url) + assert response.history == [HttpEventTrigger._convert_response(h) for h in client_response.history] + assert response.encoding == client_response.get_encoding() + assert response.reason == client_response.reason + assert dict(response.cookies) == dict(client_response.cookies) + + @pytest.mark.db_test + @pytest.mark.asyncio + @mock.patch("aiohttp.client.ClientSession.post") + async def test_trigger_on_post_with_data(self, mock_http_post, event_trigger): + """ + Test that HttpEventTrigger posts the correct payload when a request is made. + """ + generator = event_trigger.run() + with pytest.raises(StopAsyncIteration): + await generator.asend(None) + mock_http_post.assert_called_once() + _, kwargs = mock_http_post.call_args + assert kwargs["data"] == TEST_DATA + assert kwargs["json"] is None + assert kwargs["params"] is None