Skip to content
135 changes: 84 additions & 51 deletions providers/http/src/airflow/providers/http/hooks/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

import copy
from typing import TYPE_CHECKING, Any, Callable
from urllib.parse import urlparse

Expand Down Expand Up @@ -48,37 +49,50 @@ def _url_from_endpoint(base_url: str | None, endpoint: str | None) -> str:
return (base_url or "") + (endpoint or "")


def _process_extra_options_from_connection(conn: Connection, extra_options: dict[str, Any]) -> dict:
extra = conn.extra_dejson
stream = extra.pop("stream", None)
cert = extra.pop("cert", None)
proxies = extra.pop("proxies", extra.pop("proxy", None))
timeout = extra.pop("timeout", None)
verify_ssl = extra.pop("verify", extra.pop("verify_ssl", None))
allow_redirects = extra.pop("allow_redirects", None)
max_redirects = extra.pop("max_redirects", None)
trust_env = extra.pop("trust_env", None)
check_response = extra.pop("check_response", None)

if stream is not None and "stream" not in extra_options:
extra_options["stream"] = stream
if cert is not None and "cert" not in extra_options:
extra_options["cert"] = cert
if proxies is not None and "proxy" not in extra_options:
extra_options["proxy"] = proxies
if timeout is not None and "timeout" not in extra_options:
extra_options["timeout"] = timeout
if verify_ssl is not None and "verify_ssl" not in extra_options:
extra_options["verify_ssl"] = verify_ssl
if allow_redirects is not None and "allow_redirects" not in extra_options:
extra_options["allow_redirects"] = allow_redirects
if max_redirects is not None and "max_redirects" not in extra_options:
extra_options["max_redirects"] = max_redirects
if trust_env is not None and "trust_env" not in extra_options:
extra_options["trust_env"] = trust_env
if check_response is not None and "check_response" not in extra_options:
extra_options["check_response"] = check_response
return extra
def _process_extra_options_from_connection(
conn: Connection, extra_options: dict[str, Any]
) -> tuple[dict[str, Any], dict[str, Any]]:
"""
Return the updated extra options from the connection, as well as those passed.

:param conn: The HTTP Connection object passed to the Hook
:param extra_options: Use-defined extra options
:return: (tuple)
"""
# Copy, to prevent changing conn.extra_dejson and extra_options
conn_extra_options: dict = copy.copy(conn.extra_dejson)
passed_extra_options: dict = copy.copy(extra_options)

stream = conn_extra_options.pop("stream", None)
cert = conn_extra_options.pop("cert", None)
proxies = conn_extra_options.pop("proxies", conn_extra_options.pop("proxy", None))
timeout = conn_extra_options.pop("timeout", None)
verify_ssl = conn_extra_options.pop("verify", conn_extra_options.pop("verify_ssl", None))
allow_redirects = conn_extra_options.pop("allow_redirects", None)
max_redirects = conn_extra_options.pop("max_redirects", None)
trust_env = conn_extra_options.pop("trust_env", None)
check_response = conn_extra_options.pop("check_response", None)

if stream is not None and "stream" not in passed_extra_options:
passed_extra_options["stream"] = stream
if cert is not None and "cert" not in passed_extra_options:
passed_extra_options["cert"] = cert
if proxies is not None and "proxy" not in passed_extra_options:
passed_extra_options["proxy"] = proxies
if timeout is not None and "timeout" not in passed_extra_options:
passed_extra_options["timeout"] = timeout
if verify_ssl is not None and "verify_ssl" not in passed_extra_options:
passed_extra_options["verify_ssl"] = verify_ssl
if allow_redirects is not None and "allow_redirects" not in passed_extra_options:
passed_extra_options["allow_redirects"] = allow_redirects
if max_redirects is not None and "max_redirects" not in passed_extra_options:
passed_extra_options["max_redirects"] = max_redirects
if trust_env is not None and "trust_env" not in passed_extra_options:
passed_extra_options["trust_env"] = trust_env
if check_response is not None and "check_response" not in passed_extra_options:
passed_extra_options["check_response"] = check_response

return conn_extra_options, passed_extra_options


class HttpHook(BaseHook):
Expand All @@ -96,7 +110,6 @@ class HttpHook(BaseHook):
:param tcp_keep_alive_count: The TCP Keep Alive count parameter (corresponds to ``socket.TCP_KEEPCNT``)
:param tcp_keep_alive_interval: The TCP Keep Alive interval parameter (corresponds to
``socket.TCP_KEEPINTVL``)
:param auth_args: extra arguments used to initialize the auth_type if different than default HTTPBasicAuth
"""

conn_name_attr = "http_conn_id"
Expand Down Expand Up @@ -135,6 +148,8 @@ def __init__(
else:
self.keep_alive_adapter = None

self.merged_extra: dict = {}

@property
def auth_type(self):
return self._auth_type or HTTPBasicAuth
Expand All @@ -159,8 +174,14 @@ def get_conn(
connection = self.get_connection(self.http_conn_id)
self._set_base_url(connection)
session = self._configure_session_from_auth(session, connection)

# Since get_conn can be called outside of run, we'll check this again
extra_options = extra_options or {}

if connection.extra or extra_options:
# These are being passed from to _configure_session_from_extra, no manipulation has been done yet
session = self._configure_session_from_extra(session, connection, extra_options)

session = self._configure_session_from_mount_adapters(session)
if self.default_headers:
session.headers.update(self.default_headers)
Expand Down Expand Up @@ -194,21 +215,33 @@ def _extract_auth(self, connection: Connection) -> Any | None:
return None

def _configure_session_from_extra(
self, session: Session, connection: Connection, extra_options: dict[str, Any] | None = None
self, session: Session, connection: Connection, extra_options: dict[str, Any]
) -> Session:
if extra_options is None:
extra_options = {}
headers = _process_extra_options_from_connection(connection, extra_options)
session.proxies = extra_options.pop("proxies", extra_options.pop("proxy", {}))
session.stream = extra_options.pop("stream", False)
session.verify = extra_options.pop("verify", extra_options.pop("verify_ssl", True))
session.cert = extra_options.pop("cert", None)
session.max_redirects = extra_options.pop("max_redirects", DEFAULT_REDIRECT_LIMIT)
session.trust_env = extra_options.pop("trust_env", True)
"""
Configure the session using both the extra field from the Connection and passed in extra_options.

:param session: (Session)
:param connection: HTTP Connection passed into Hook
:param extra_options: (dict)
:return: (Session)
"""
# This is going to update self.merged_extra, which will be used below
conn_extra_options, self.merged_extra = _process_extra_options_from_connection(
connection, extra_options
)

session.proxies = self.merged_extra.get("proxies", self.merged_extra.get("proxy", {}))
session.stream = self.merged_extra.get("stream", False)
session.verify = self.merged_extra.get("verify", self.merged_extra.get("verify_ssl", True))
session.cert = self.merged_extra.get("cert", None)
session.max_redirects = self.merged_extra.get("max_redirects", DEFAULT_REDIRECT_LIMIT)
session.trust_env = self.merged_extra.get("trust_env", True)

try:
session.headers.update(headers)
session.headers.update(conn_extra_options)
except TypeError:
self.log.warning("Connection to %s has invalid extra field.", connection.host)

return session

def _configure_session_from_mount_adapters(self, session: Session) -> Session:
Expand Down Expand Up @@ -245,9 +278,7 @@ def run(
For example, ``run(json=obj)`` is passed as ``requests.Request(json=obj)``
"""
extra_options = extra_options or {}

session = self.get_conn(headers, extra_options)

session = self.get_conn(headers, extra_options) # This sets self.merged_extra, which is used later
url = self.url_from_endpoint(endpoint)

if self.method == "GET":
Expand All @@ -262,7 +293,9 @@ def run(

prepped_request = session.prepare_request(req)
self.log.debug("Sending '%s' to url: %s", self.method, url)
return self.run_and_check(session, prepped_request, extra_options)

# This is referencing self.merged_extra, which is update by _process ...
return self.run_and_check(session, prepped_request, self.merged_extra)

def check_response(self, response: Response) -> None:
"""
Expand Down Expand Up @@ -294,8 +327,6 @@ def run_and_check(
i.e. ``{'check_response': False}`` to avoid checking raising exceptions on non 2XX
or 3XX status codes
"""
extra_options = extra_options or {}

settings = session.merge_environment_settings(
prepped_request.url,
proxies=session.proxies,
Expand Down Expand Up @@ -439,10 +470,12 @@ async def run(
if conn.login:
auth = self.auth_type(conn.login, conn.password)
if conn.extra:
extra = _process_extra_options_from_connection(conn=conn, extra_options=extra_options)
conn_extra_options, extra_options = _process_extra_options_from_connection(
conn=conn, extra_options=extra_options
)

try:
_headers.update(extra)
_headers.update(conn_extra_options)
except TypeError:
self.log.warning("Connection to %s has invalid extra field.", conn.host)
if headers:
Expand Down
9 changes: 6 additions & 3 deletions providers/http/tests/unit/http/hooks/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,9 +655,11 @@ def test_process_extra_options_from_connection(self):
}
)()

actual = _process_extra_options_from_connection(conn=conn, extra_options=extra_options)
actual_conn_extra, actual_merged_extra = _process_extra_options_from_connection(
conn=conn, extra_options=extra_options
)

assert extra_options == {
assert actual_merged_extra == {
"cert": "cert.crt",
"stream": True,
"proxy": proxy,
Expand All @@ -667,7 +669,8 @@ def test_process_extra_options_from_connection(self):
"max_redirects": 3,
"trust_env": False,
}
assert actual == {"bearer": "test"}
assert actual_conn_extra == {"bearer": "test"}
assert extra_options == {}


class TestHttpAsyncHook:
Expand Down