diff --git a/providers/http/src/airflow/providers/http/operators/http.py b/providers/http/src/airflow/providers/http/operators/http.py
index ee03180881b80..2c7f260f39e12 100644
--- a/providers/http/src/airflow/providers/http/operators/http.py
+++ b/providers/http/src/airflow/providers/http/operators/http.py
@@ -22,13 +22,14 @@
 from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any, Callable
 
+from aiohttp import BasicAuth
 from requests import Response
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
 from airflow.models import BaseOperator
-from airflow.providers.http.triggers.http import HttpTrigger
+from airflow.providers.http.triggers.http import HttpTrigger, serialize_auth_type
 from airflow.utils.helpers import merge_dicts
 
 if TYPE_CHECKING:
@@ -122,7 +123,7 @@ def __init__(
         request_kwargs: dict[str, Any] | None = None,
         http_conn_id: str = "http_default",
         log_response: bool = False,
-        auth_type: type[AuthBase] | None = None,
+        auth_type: type[AuthBase] | type[BasicAuth] | None = None,
         tcp_keep_alive: bool = True,
         tcp_keep_alive_idle: int = 120,
         tcp_keep_alive_count: int = 20,
@@ -221,7 +222,7 @@ def execute_async(self, context: Context) -> None:
         self.defer(
             trigger=HttpTrigger(
                 http_conn_id=self.http_conn_id,
-                auth_type=self.auth_type,
+                auth_type=serialize_auth_type(self._resolve_auth_type()),
                 method=self.method,
                 endpoint=self.endpoint,
                 headers=self.headers,
@@ -231,6 +232,27 @@ def execute_async(self, context: Context) -> None:
             method_name="execute_complete",
         )
 
+    def _resolve_auth_type(self) -> type[AuthBase] | type[BasicAuth] | None:
+        """
+        Resolve the authentication type for the HTTP request.
+
+        If auth_type is not explicitly set, attempt to infer it from the connection configuration.
+        For connections with login/password, default to BasicAuth.
+
+        :return: The resolved authentication type class, or None if no auth is needed.
+        """
+        if self.auth_type is not None:
+            return self.auth_type
+
+        try:
+            conn = BaseHook.get_connection(self.http_conn_id)
+            if conn.login or conn.password:
+                return BasicAuth
+        except Exception as e:
+            self.log.warning("Failed to resolve auth type from connection: %s", e)
+
+        return None
+
     def process_response(self, context: Context, response: Response | list[Response]) -> Any:
         """Process the response."""
         from airflow.utils.operator_helpers import determine_kwargs
@@ -291,7 +313,7 @@ def paginate_async(
             self.defer(
                 trigger=HttpTrigger(
                     http_conn_id=self.http_conn_id,
-                    auth_type=self.auth_type,
+                    auth_type=serialize_auth_type(self._resolve_auth_type()),
                     method=self.method,
                     **self._merge_next_page_parameters(next_page_params),
                 ),
diff --git a/providers/http/src/airflow/providers/http/triggers/http.py b/providers/http/src/airflow/providers/http/triggers/http.py
index d25d3a55cfb5b..6c1d13b136334 100644
--- a/providers/http/src/airflow/providers/http/triggers/http.py
+++ b/providers/http/src/airflow/providers/http/triggers/http.py
@@ -20,6 +20,7 @@
 import base64
 import pickle
 from collections.abc import AsyncIterator
+from importlib import import_module
 from typing import TYPE_CHECKING, Any
 
 import aiohttp
@@ -35,6 +36,54 @@
     from aiohttp.client_reqrep import ClientResponse
 
 
+def serialize_auth_type(auth: str | type | None) -> str | None:
+    """
+    Convert an auth class into a dotted import path that can be stored with the trigger.
+
+    :param auth: An auth class, an already serialized dotted path, or None.
+    :return: ``"<module>.<qualname>"`` for a class, the string unchanged, or None.
+    """
+    if auth is None:
+        return None
+    if isinstance(auth, str):
+        return auth
+    return f"{auth.__module__}.{auth.__qualname__}"
+
+
+def deserialize_auth_type(path: str | None) -> type | None:
+    """
+    Import the object named by a dotted path produced by ``serialize_auth_type``.
+
+    A qualified name may itself contain dots (a class nested in another class), so the longest
+    importable prefix of ``path`` is used as the module and the remainder is resolved as attributes.
+
+    :param path: Dotted import path such as ``"aiohttp.BasicAuth"``, or None.
+    :return: The object found at ``path``, or None when ``path`` is None.
+    :raises ValueError: If ``path`` is not of the form ``"<module>.<name>"``.
+    :raises ModuleNotFoundError: If no prefix of ``path`` is an importable module.
+    :raises AttributeError: If the module does not define the requested name.
+    """
+    if path is None:
+        return None
+    parts = path.split(".")
+    if len(parts) < 2 or not all(parts):
+        raise ValueError(f"auth_type must be a dotted path like 'module.Class', got {path!r}")
+    for split_at in range(len(parts) - 1, 0, -1):
+        module_path = ".".join(parts[:split_at])
+        try:
+            resolved: Any = import_module(module_path)
+        except ModuleNotFoundError as err:
+            # Skip only when this prefix itself is not a module; a dependency that is missing
+            # inside a real module must surface instead of being masked.
+            if err.name is None or not f"{module_path}.".startswith(f"{err.name}."):
+                raise
+            continue
+        for attr in parts[split_at:]:
+            resolved = getattr(resolved, attr)
+        return resolved
+    raise ModuleNotFoundError(f"Cannot import auth_type {path!r}: no importable module in path")
+
+
 class HttpTrigger(BaseTrigger):
     """
     HttpTrigger run on the trigger worker.
@@ -56,7 +105,7 @@ class HttpTrigger(BaseTrigger):
     def __init__(
         self,
         http_conn_id: str = "http_default",
-        auth_type: Any = None,
+        auth_type: str | None = None,
         method: str = "POST",
         endpoint: str | None = None,
         headers: dict[str, str] | None = None,
@@ -66,7 +115,7 @@ def __init__(
         super().__init__()
         self.http_conn_id = http_conn_id
         self.method = method
-        self.auth_type = auth_type
+        self.auth_type = deserialize_auth_type(auth_type)
         self.endpoint = endpoint
         self.headers = headers
         self.data = data
@@ -79,7 +128,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
             {
                 "http_conn_id": self.http_conn_id,
                 "method": self.method,
-                "auth_type": self.auth_type,
+                "auth_type": serialize_auth_type(self.auth_type),
                 "endpoint": self.endpoint,
                 "headers": self.headers,
                 "data": self.data,
diff --git a/providers/http/tests/unit/http/operators/test_http.py b/providers/http/tests/unit/http/operators/test_http.py
index 82787cb54304c..31be860cef6c0 100644
--- a/providers/http/tests/unit/http/operators/test_http.py
+++ b/providers/http/tests/unit/http/operators/test_http.py
@@ -21,18 +21,21 @@
 import contextlib
 import json
 import pickle
+from types import SimpleNamespace
 from unittest import mock
 from unittest.mock import call, patch
 
 import pytest
 import tenacity
+from aiohttp import BasicAuth
 from requests import Response
 from requests.models import RequestEncodingMixin
 
 from airflow.exceptions import AirflowException, TaskDeferred
+from airflow.hooks import base
 from airflow.providers.http.hooks.http import HttpHook
 from airflow.providers.http.operators.http import HttpOperator
-from airflow.providers.http.triggers.http import HttpTrigger
+from airflow.providers.http.triggers.http import HttpTrigger, serialize_auth_type
 
 
 @mock.patch.dict("os.environ", AIRFLOW_CONN_HTTP_EXAMPLE="http://www.example.com")
@@ -92,6 +95,7 @@ def test_filters_response(self, requests_mock):
         result = operator.execute({})
         assert result == {"value": 5}
 
+    @pytest.mark.db_test
     def test_async_defer_successfully(self, requests_mock):
         operator = HttpOperator(
             task_id="test_HTTP_op",
@@ -186,6 +190,7 @@ def pagination_function(response: Response) -> dict | None:
 
         assert result == [5, 10]
 
+    @pytest.mark.db_test
     def test_async_pagination(self, requests_mock):
         """
         Test that the HttpOperator calls asynchronously and repetitively
@@ -300,3 +305,53 @@ def pagination_function(response: Response) -> dict | None:
         )
 
         assert mock_run_with_advanced_retry.call_count == 2
+
+    def _capture_defer(self, monkeypatch):
+        captured = {}
+
+        def _fake_defer(self, *, trigger, method_name, **kwargs):
+            captured["trigger"] = trigger
+            captured["kwargs"] = kwargs
+
+        monkeypatch.setattr(HttpOperator, "defer", _fake_defer)
+        return captured
+
+    @pytest.mark.parametrize(
+        "login, password, auth_type, expect_cls",
+        [
+            ("user", "password", None, BasicAuth),
+            (None, None, None, type(None)),
+            ("user", "password", BasicAuth, BasicAuth),
+        ],
+    )
+    def test_auth_type_is_serialised_as_string(self, monkeypatch, login, password, auth_type, expect_cls):
+        monkeypatch.setattr(
+            base.BaseHook, "get_connection", lambda _cid: SimpleNamespace(login=login, password=password)
+        )
+        captured = self._capture_defer(monkeypatch)
+
+        HttpOperator(task_id="test_HTTP_op", deferrable=True, auth_type=auth_type).execute(context={})
+
+        trigger = captured["trigger"]
+        kwargs = captured["trigger"].serialize()[1]
+
+        expected_str = serialize_auth_type(expect_cls) if expect_cls is not type(None) else None
+        assert kwargs["auth_type"] == expected_str
+
+        assert trigger.auth_type == expect_cls or (trigger.auth_type is None and expect_cls is type(None))
+
+    def test_resolve_auth_type_variants(self, monkeypatch):
+        monkeypatch.setattr(
+            base.BaseHook, "get_connection", lambda _cid: SimpleNamespace(login="user", password="password")
+        )
+        assert HttpOperator(task_id="test_HTTP_op_1")._resolve_auth_type() is BasicAuth
+
+        class DummyAuth:
+            def __init__(self, *_, **__): ...
+
+        assert HttpOperator(task_id="test_HTTP_op_2", auth_type=DummyAuth)._resolve_auth_type() is DummyAuth
+
+        monkeypatch.setattr(
+            base.BaseHook, "get_connection", lambda _cid: SimpleNamespace(login=None, password=None)
+        )
+        assert HttpOperator(task_id="test_HTTP_op_3")._resolve_auth_type() is None