diff --git a/airflow/providers/airbyte/CHANGELOG.rst b/airflow/providers/airbyte/CHANGELOG.rst index cef7dda80708a..24491fb7decc0 100644 --- a/airflow/providers/airbyte/CHANGELOG.rst +++ b/airflow/providers/airbyte/CHANGELOG.rst @@ -19,6 +19,17 @@ Changelog --------- +2.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +The Airbyte Hook now derives from the new version of the Http Hook in the Http provider. + +TODO: Add an extra dependency on the 2.0.0 version of the Http provider. + + 1.0.0 ..... diff --git a/airflow/providers/airbyte/provider.yaml b/airflow/providers/airbyte/provider.yaml index 77b109f45058d..08b44be18c815 100644 --- a/airflow/providers/airbyte/provider.yaml +++ b/airflow/providers/airbyte/provider.yaml @@ -22,6 +22,7 @@ description: | `Airbyte `__ versions: + - 2.0.0 - 1.0.0 integrations: diff --git a/airflow/providers/apache/livy/CHANGELOG.rst b/airflow/providers/apache/livy/CHANGELOG.rst index d43fe14754f4c..aaa65d7d29c17 100644 --- a/airflow/providers/apache/livy/CHANGELOG.rst +++ b/airflow/providers/apache/livy/CHANGELOG.rst @@ -19,6 +19,14 @@ Changelog --------- +2.0.0 +..... + +The Livy Hook now derives from the new version of the Http Hook in the Http provider. + +TODO: Add an extra dependency on the 2.0.0 version of the Http provider. + + 1.1.0 ..... 
diff --git a/airflow/providers/apache/livy/hooks/livy.py b/airflow/providers/apache/livy/hooks/livy.py index 75d08af56883f..0c4d6e972da4d 100644 --- a/airflow/providers/apache/livy/hooks/livy.py +++ b/airflow/providers/apache/livy/hooks/livy.py @@ -21,7 +21,7 @@ from enum import Enum from typing import Any, Dict, List, Optional, Sequence, Union -import requests +import httpx from airflow.exceptions import AirflowException from airflow.providers.http.hooks.http import HttpHook @@ -75,19 +75,25 @@ def __init__( super().__init__(http_conn_id=livy_conn_id) self.extra_options = extra_options or {} - def get_conn(self, headers: Optional[Dict[str, Any]] = None) -> Any: + def get_conn( + self, headers: Optional[Dict[Any, Any]] = None, verify: bool = True, proxies=None, cert=None + ) -> httpx.Client: """ Returns http session for use with requests :param headers: additional headers to be passed through as a dictionary :type headers: dict - :return: requests session - :rtype: requests.Session + :param verify: whether to verify SSL during the connection (only use for testing) + :param proxies: A dictionary mapping proxy keys to proxy + :param cert: client An SSL certificate used by the requested host + to authenticate the client. Either a path to an SSL certificate file, or + two-tuple of (certificate file, key file), or a three-tuple of (certificate + file, key file, password). 
""" tmp_headers = self._def_headers.copy() # setting default headers if headers: tmp_headers.update(headers) - return super().get_conn(tmp_headers) + return super().get_conn(tmp_headers, verify, proxies, cert) def run_method( self, @@ -108,7 +114,7 @@ def run_method( :param headers: headers :type headers: dict :return: http response - :rtype: requests.Response + :rtype: httpx.Response """ if method not in ('GET', 'POST', 'PUT', 'DELETE', 'HEAD'): raise ValueError(f"Invalid http method '{method}'") @@ -142,7 +148,7 @@ def post_batch(self, *args: Any, **kwargs: Any) -> Any: try: response.raise_for_status() - except requests.exceptions.HTTPError as err: + except httpx.HTTPStatusError as err: raise AirflowException( "Could not submit batch. Status code: {}. Message: '{}'".format( err.response.status_code, err.response.text @@ -172,7 +178,7 @@ def get_batch(self, session_id: Union[int, str]) -> Any: try: response.raise_for_status() - except requests.exceptions.HTTPError as err: + except httpx.HTTPStatusError as err: self.log.warning("Got status code %d for session %d", err.response.status_code, session_id) raise AirflowException( f"Unable to fetch batch with id: {session_id}. Message: {err.response.text}" @@ -196,7 +202,7 @@ def get_batch_state(self, session_id: Union[int, str]) -> BatchState: try: response.raise_for_status() - except requests.exceptions.HTTPError as err: + except httpx.HTTPStatusError as err: self.log.warning("Got status code %d for session %d", err.response.status_code, session_id) raise AirflowException( f"Unable to fetch batch with id: {session_id}. Message: {err.response.text}" @@ -223,7 +229,7 @@ def delete_batch(self, session_id: Union[int, str]) -> Any: try: response.raise_for_status() - except requests.exceptions.HTTPError as err: + except httpx.HTTPStatusError as err: self.log.warning("Got status code %d for session %d", err.response.status_code, session_id) raise AirflowException( "Could not kill the batch with session id: {}. 
Message: {}".format( diff --git a/airflow/providers/apache/livy/provider.yaml b/airflow/providers/apache/livy/provider.yaml index 309c2f538aa1a..f6999fa4f7e5c 100644 --- a/airflow/providers/apache/livy/provider.yaml +++ b/airflow/providers/apache/livy/provider.yaml @@ -22,6 +22,7 @@ description: | `Apache Livy `__ versions: + - 2.0.0 - 1.1.0 - 1.0.1 - 1.0.0 diff --git a/airflow/providers/http/CHANGELOG.rst b/airflow/providers/http/CHANGELOG.rst index b4a0313336731..dd3f98d3786c8 100644 --- a/airflow/providers/http/CHANGELOG.rst +++ b/airflow/providers/http/CHANGELOG.rst @@ -19,6 +19,36 @@ Changelog --------- + +2.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +Due to licensing issues, the HTTP provider switched from the ``requests`` to the ``httpx`` library. +In case you use an authentication method different than the default Basic Authentication, +you will need to change it to use the httpx-compatible one. + +HttpHook's run() method passes any kwargs passed to the underlying httpx.Request object rather than +to requests.Request. They are largely compatible (for example json kwarg is supported in both) but not fully. +The ``content`` and ``auth`` parameters are not supported in ``httpx``. + +The "verify", "proxies" and "cert" extra options are now passed to get_conn() when the client is created. + +The get_conn() method of HttpHook returns ``httpx.Client`` rather than ``requests.Session``. + +The get_conn() method of HttpHook has extra parameters: verify, proxies and cert, which are passed by the +run() method, so if your hook derives from the HttpHook it should be updated. + +The ``run_and_check`` method is gone. PreparedRequests are not supported in ``httpx`` and this method +used them. Instead, the run() method simply creates and executes the full request instance. + +While ``httpx`` does not support the ``REQUESTS_CA_BUNDLE`` variable overriding the CA certificate bundle, we backported +the requests behaviour to HttpHook and it respects the variable if set. 
More details on that variable +can be found `here None: super().__init__() self.http_conn_id = http_conn_id self.method = method.upper() self.base_url: str = "" self._retry_obj: Callable[..., Any] - self.auth_type: Any = auth_type + self.auth_type: Type[Auth] = auth_type + self.extra_headers: Dict = {} + self.conn = None # headers may be passed through directly or in the "extra" field in the connection # definition - def get_conn(self, headers: Optional[Dict[Any, Any]] = None) -> requests.Session: + def get_conn( + self, headers: Optional[Dict[Any, Any]] = None, verify: bool = True, proxies=None, cert=None + ) -> httpx.Client: """ - Returns http session for use with requests + Returns httpx client for use with requests :param headers: additional headers to be passed through as a dictionary :type headers: dict + :param verify: whether to verify SSL during the connection (only use for testing) + :param proxies: A dictionary mapping proxy keys to proxy + :param cert: client An SSL certificate used by the requested host + to authenticate the client. Either a path to an SSL certificate file, or + two-tuple of (certificate file, key file), or a three-tuple of (certificate + file, key file, password). """ - session = requests.Session() + client = httpx.Client(verify=verify, proxies=proxies, cert=cert) if self.http_conn_id: conn = self.get_connection(self.http_conn_id) - if conn.host and "://" in conn.host: self.base_url = conn.host else: @@ -82,16 +93,20 @@ def get_conn(self, headers: Optional[Dict[Any, Any]] = None) -> requests.Session if conn.port: self.base_url = self.base_url + ":" + str(conn.port) if conn.login: - session.auth = self.auth_type(conn.login, conn.password) + # Note! This handles Basic Auth and DigestAuth and any other authentication that + # supports login/password in the constructor. 
+ client.auth = self.auth_type(conn.login, conn.password) # noqa if conn.extra: try: - session.headers.update(conn.extra_dejson) + client.headers.update(conn.extra_dejson) except TypeError: self.log.warning('Connection to %s has invalid extra field.', conn.host) + # Hooks deriving from HTTP Hook might use it to query connection details + self.conn = conn if headers: - session.headers.update(headers) + client.headers.update(headers) - return session + return client def run( self, @@ -111,16 +126,24 @@ def run( :param headers: additional headers to be passed through as a dictionary :type headers: dict :param extra_options: additional options to be used when executing the request - i.e. {'check_response': False} to avoid checking raising exceptions on non - 2XX or 3XX status codes + i.e. ``{'check_response': False}`` to avoid checking raising exceptions on non + 2XX or 3XX status codes. The extra options can take the following keys: + verify, proxies, cert, stream, allow_redirects, timeout, check_response. + See ``httpx.Client`` for description of those parameters. :type extra_options: dict :param request_kwargs: Additional kwargs to pass when creating a request. 
- For example, ``run(json=obj)`` is passed as ``requests.Request(json=obj)`` + For example, ``run(json=obj)`` is passed as ``httpx.Request(json=obj)`` """ extra_options = extra_options or {} - session = self.get_conn(headers) - + verify = extra_options.get("verify", True) + if verify is True: + # Only use REQUESTS_CA_BUNDLE content if verify is set to True, + # otherwise use passed value as it can be string, or SSLcontext + verify = os.environ.get('REQUESTS_CA_BUNDLE', verify) + proxies = extra_options.get("proxies", None) + cert = extra_options.get("cert", None) + client = self.get_conn(headers, verify=verify, proxies=proxies, cert=cert) if self.base_url and not self.base_url.endswith('/') and endpoint and not endpoint.startswith('/'): url = self.base_url + '/' + endpoint else: @@ -128,19 +151,33 @@ def run( if self.method == 'GET': # GET uses params - req = requests.Request(self.method, url, params=data, headers=headers, **request_kwargs) + req = httpx.Request(self.method, url, params=data, headers=client.headers, **request_kwargs) elif self.method == 'HEAD': # HEAD doesn't use params - req = requests.Request(self.method, url, headers=headers, **request_kwargs) + req = httpx.Request(self.method, url, headers=client.headers, **request_kwargs) else: # Others use data - req = requests.Request(self.method, url, data=data, headers=headers, **request_kwargs) + req = httpx.Request(self.method, url, headers=client.headers, data=data, **request_kwargs) - prepped_request = session.prepare_request(req) + # Send the request + send_kwargs = { + "stream": extra_options.get("stream", False), + "allow_redirects": extra_options.get("allow_redirects", True), + "timeout": extra_options.get("timeout"), + } self.log.info("Sending '%s' to url: %s", self.method, url) - return self.run_and_check(session, prepped_request, extra_options) - def check_response(self, response: requests.Response) -> None: + try: + response = client.send(req, **send_kwargs) + if 
extra_options.get('check_response', True): + self.check_response(response) + return response + + except httpx.NetworkError as ex: + self.log.warning('%s Tenacity will retry to execute the operation', ex) + raise ex + + def check_response(self, response: httpx.Response) -> None: """ Checks the status code and raise an AirflowException exception on non 2XX or 3XX status codes @@ -150,57 +187,15 @@ def check_response(self, response: requests.Response) -> None: """ try: response.raise_for_status() - except requests.exceptions.HTTPError: - self.log.error("HTTP error: %s", response.reason) + except httpx.HTTPError: + phrase = 'Unknown' + try: + phrase = http.HTTPStatus(response.status_code).phrase + except ValueError: + pass + self.log.error("HTTP error: %s", phrase) self.log.error(response.text) - raise AirflowException(str(response.status_code) + ":" + response.reason) - - def run_and_check( - self, - session: requests.Session, - prepped_request: requests.PreparedRequest, - extra_options: Dict[Any, Any], - ) -> Any: - """ - Grabs extra options like timeout and actually runs the request, - checking for the result - - :param session: the session to be used to execute the request - :type session: requests.Session - :param prepped_request: the prepared request generated in run() - :type prepped_request: session.prepare_request - :param extra_options: additional options to be used when executing the request - i.e. {'check_response': False} to avoid checking raising exceptions on non 2XX - or 3XX status codes - :type extra_options: dict - """ - extra_options = extra_options or {} - - settings = session.merge_environment_settings( - prepped_request.url, - proxies=extra_options.get("proxies", {}), - stream=extra_options.get("stream", False), - verify=extra_options.get("verify"), - cert=extra_options.get("cert"), - ) - - # Send the request. 
- send_kwargs = { - "timeout": extra_options.get("timeout"), - "allow_redirects": extra_options.get("allow_redirects", True), - } - send_kwargs.update(settings) - - try: - response = session.send(prepped_request, **send_kwargs) - - if extra_options.get('check_response', True): - self.check_response(response) - return response - - except requests.exceptions.ConnectionError as ex: - self.log.warning('%s Tenacity will retry to execute the operation', ex) - raise ex + raise AirflowException(str(response.status_code) + ":" + response.text) def run_with_advanced_retry(self, _retry_args: Dict[Any, Any], *args: Any, **kwargs: Any) -> Any: """ diff --git a/airflow/providers/http/operators/http.py b/airflow/providers/http/operators/http.py index b6295185d8ba8..fa04cb56a3297 100644 --- a/airflow/providers/http/operators/http.py +++ b/airflow/providers/http/operators/http.py @@ -1,4 +1,3 @@ -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -15,9 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+ from typing import Any, Callable, Dict, Optional, Type -from requests.auth import AuthBase, HTTPBasicAuth +from httpx import Auth, BasicAuth from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -86,7 +86,7 @@ def __init__( extra_options: Optional[Dict[str, Any]] = None, http_conn_id: str = 'http_default', log_response: bool = False, - auth_type: Type[AuthBase] = HTTPBasicAuth, + auth_type: Type[Auth] = BasicAuth, **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/airflow/providers/http/provider.yaml b/airflow/providers/http/provider.yaml index 79a095b0a4417..cf1145b5cb6be 100644 --- a/airflow/providers/http/provider.yaml +++ b/airflow/providers/http/provider.yaml @@ -22,6 +22,7 @@ description: | `Hypertext Transfer Protocol (HTTP) `__ versions: + - 2.0.0 - 1.1.1 - 1.1.0 - 1.0.0 diff --git a/airflow/providers/opsgenie/CHANGELOG.rst b/airflow/providers/opsgenie/CHANGELOG.rst index dd5b1de2f0f31..bae90a5d42015 100644 --- a/airflow/providers/opsgenie/CHANGELOG.rst +++ b/airflow/providers/opsgenie/CHANGELOG.rst @@ -19,6 +19,16 @@ Changelog --------- +2.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +The OpsGenie Hook now derives from the new version of the Http Hook in the Http provider. + +TODO: Add an extra dependency on the 2.0.0 version of the Http provider. + 1.0.2 ..... 
diff --git a/airflow/providers/opsgenie/hooks/opsgenie_alert.py b/airflow/providers/opsgenie/hooks/opsgenie_alert.py index 60f1734fdccbc..4e1ff796198f3 100644 --- a/airflow/providers/opsgenie/hooks/opsgenie_alert.py +++ b/airflow/providers/opsgenie/hooks/opsgenie_alert.py @@ -18,9 +18,9 @@ # import json -from typing import Any, Optional +from typing import Any, Dict, Optional -import requests +import httpx from airflow.exceptions import AirflowException from airflow.providers.http.hooks.http import HttpHook @@ -54,20 +54,25 @@ def _get_api_key(self) -> str: ) return api_key - def get_conn(self, headers: Optional[dict] = None) -> requests.Session: + def get_conn( + self, headers: Optional[Dict[Any, Any]] = None, verify: bool = True, proxies=None, cert=None + ) -> httpx.Client: """ Overwrite HttpHook get_conn because this hook just needs base_url and headers, and does not need generic params :param headers: additional headers to be passed through as a dictionary :type headers: dict + :param verify: whether to verify SSL during the connection (only use for testing) + :param proxies: A dictionary mapping proxy keys to proxy + :param cert: client An SSL certificate used by the requested host + to authenticate the client. Either a path to an SSL certificate file, or + two-tuple of (certificate file, key file), or a three-tuple of (certificate + file, key file, password). 
""" - conn = self.get_connection(self.http_conn_id) - self.base_url = conn.host if conn.host else 'https://api.opsgenie.com' - session = requests.Session() - if headers: - session.headers.update(headers) - return session + client = super().get_conn(headers=headers, verify=verify, proxies=proxies, cert=cert) + self.base_url = self.conn.host if self.conn.host else 'https://api.opsgenie.com' + return client def execute(self, payload: Optional[dict] = None) -> Any: """ diff --git a/airflow/providers/opsgenie/provider.yaml b/airflow/providers/opsgenie/provider.yaml index c2e8d22546d10..f2a18ac6643b2 100644 --- a/airflow/providers/opsgenie/provider.yaml +++ b/airflow/providers/opsgenie/provider.yaml @@ -22,6 +22,7 @@ description: | `Opsgenie `__ versions: + - 2.0.0 - 1.0.2 - 1.0.1 - 1.0.0 diff --git a/setup.py b/setup.py index b057f51ec529c..61fe38635029a 100644 --- a/setup.py +++ b/setup.py @@ -343,20 +343,7 @@ def get_sphinx_theme_version() -> str: 'pyhive[hive]>=0.6.0', 'thrift>=0.9.2', ] -http = [ - 'requests>=2.20.0', -] -http_provider = [ - # NOTE ! The HTTP provider is NOT preinstalled by default when Airflow is installed - because it - # depends on `requests` library and until `chardet` is mandatory dependency of `requests` - # See https://github.com/psf/requests/pull/5797 - # This means that providers that depend on Http and cannot work without it, have to have - # explicit dependency on `apache-airflow-providers-http` which needs to be pulled in for them. - # Other cross-provider-dependencies are optional (usually cross-provider dependencies only enable - # some features of providers and majority of those providers works). They result with an extra, - # not with the `install-requires` dependency. 
- 'apache-airflow-providers-http', -] +http = [] jdbc = [ 'jaydebeapi>=1.1.1', ] @@ -543,7 +530,7 @@ def get_sphinx_theme_version() -> str: # Dict of all providers which are part of the Apache Airflow repository together with their requirements PROVIDERS_REQUIREMENTS: Dict[str, List[str]] = { - 'airbyte': http_provider, + 'airbyte': http, 'amazon': amazon, 'apache.beam': apache_beam, 'apache.cassandra': cassandra, @@ -551,7 +538,7 @@ def get_sphinx_theme_version() -> str: 'apache.hdfs': hdfs, 'apache.hive': hive, 'apache.kylin': kylin, - 'apache.livy': http_provider, + 'apache.livy': http, 'apache.pig': [], 'apache.pinot': pinot, 'apache.spark': spark, @@ -584,7 +571,7 @@ def get_sphinx_theme_version() -> str: 'neo4j': neo4j, 'odbc': odbc, 'openfaas': [], - 'opsgenie': http_provider, + 'opsgenie': http, 'oracle': oracle, 'pagerduty': pagerduty, 'papermill': papermill, diff --git a/tests/providers/airbyte/hooks/test_airbyte.py b/tests/providers/airbyte/hooks/test_airbyte.py index 09f10beffc255..45ba2b107c2d5 100644 --- a/tests/providers/airbyte/hooks/test_airbyte.py +++ b/tests/providers/airbyte/hooks/test_airbyte.py @@ -20,7 +20,6 @@ from unittest import mock import pytest -import requests_mock from airflow.exceptions import AirflowException from airflow.models import Connection @@ -34,12 +33,7 @@ class TestAirbyteHook(unittest.TestCase): """ airbyte_conn_id = 'airbyte_conn_id_test' - connection_id = 'conn_test_sync' - job_id = 1 - sync_connection_endpoint = 'http://test-airbyte:8001/api/v1/connections/sync' - get_job_endpoint = 'http://test-airbyte:8001/api/v1/jobs/get' - _mock_sync_conn_success_response_body = {'job': {'id': 1}} - _mock_job_status_success_response_body = {'job': {'status': 'succeeded'}} + job_id = '1' def setUp(self): db.merge_conn( @@ -54,22 +48,6 @@ def return_value_get_job(self, status): response.json.return_value = {'job': {'status': status}} return response - @requests_mock.mock() - def test_submit_sync_connection(self, m): - m.post( - 
self.sync_connection_endpoint, status_code=200, json=self._mock_sync_conn_success_response_body - ) - resp = self.hook.submit_sync_connection(connection_id=self.connection_id) - assert resp.status_code == 200 - assert resp.json() == self._mock_sync_conn_success_response_body - - @requests_mock.mock() - def test_get_job_status(self, m): - m.post(self.get_job_endpoint, status_code=200, json=self._mock_job_status_success_response_body) - resp = self.hook.get_job(job_id=self.job_id) - assert resp.status_code == 200 - assert resp.json() == self._mock_job_status_success_response_body - @mock.patch('airflow.providers.airbyte.hooks.airbyte.AirbyteHook.get_job') def test_wait_for_job_succeeded(self, mock_get_job): mock_get_job.side_effect = [self.return_value_get_job(self.hook.SUCCEEDED)] @@ -124,3 +102,38 @@ def test_wait_for_job_cancelled(self, mock_get_job): calls = [mock.call(job_id=self.job_id), mock.call(job_id=self.job_id)] assert mock_get_job.has_calls(calls) + + +@pytest.fixture +def setup_hook(): + yield AirbyteHook(airbyte_conn_id='airbyte_conn_id_test') + + +class TestAirbyteMockHttpx: + sync_connection_endpoint = 'http://test-airbyte:8001/api/v1/connections/sync' + get_job_endpoint = 'http://test-airbyte:8001/api/v1/jobs/get' + connection_id = 'conn_test_sync' + _mock_sync_conn_success_response_body = {'job': {'id': 1}} + _mock_job_status_success_response_body = {'job': {'status': 'succeeded'}} + + def test_submit_sync_connection(self, httpx_mock, setup_hook): + httpx_mock.add_response( + method='POST', + url=self.sync_connection_endpoint, + status_code=200, + json=self._mock_sync_conn_success_response_body, + ) + resp = setup_hook.submit_sync_connection(connection_id=self.connection_id) + assert resp.status_code == 200 + assert resp.json() == self._mock_sync_conn_success_response_body + + def test_get_job_status(self, httpx_mock, setup_hook): + httpx_mock.add_response( + method='POST', + url=self.get_job_endpoint, + status_code=200, + 
json=self._mock_job_status_success_response_body, + ) + resp = setup_hook.get_job(job_id='1') + assert resp.status_code == 200 + assert resp.json() == self._mock_job_status_success_response_body diff --git a/tests/providers/apache/livy/hooks/test_livy.py b/tests/providers/apache/livy/hooks/test_livy.py index 86b7acaa6fc39..ed9b6d27f66f8 100644 --- a/tests/providers/apache/livy/hooks/test_livy.py +++ b/tests/providers/apache/livy/hooks/test_livy.py @@ -14,14 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=unused-argument import json import unittest -from unittest.mock import patch +import httpx import pytest -import requests_mock -from requests.exceptions import RequestException from airflow.exceptions import AirflowException from airflow.models import Connection @@ -33,39 +32,6 @@ class TestLivyHook(unittest.TestCase): - @classmethod - def setUpClass(cls): - db.merge_conn( - Connection(conn_id='livy_default', conn_type='http', host='host', schema='http', port=8998) - ) - db.merge_conn(Connection(conn_id='default_port', conn_type='http', host='http://host')) - db.merge_conn(Connection(conn_id='default_protocol', conn_type='http', host='host')) - db.merge_conn(Connection(conn_id='port_set', host='host', conn_type='http', port=1234)) - db.merge_conn(Connection(conn_id='schema_set', host='host', conn_type='http', schema='zzz')) - db.merge_conn( - Connection(conn_id='dont_override_schema', conn_type='http', host='http://host', schema='zzz') - ) - db.merge_conn(Connection(conn_id='missing_host', conn_type='http', port=1234)) - db.merge_conn(Connection(conn_id='invalid_uri', uri='http://invalid_uri:4321')) - - def test_build_get_hook(self): - - connection_url_mapping = { - # id, expected - 'default_port': 'http://host', - 'default_protocol': 'http://host', - 'port_set': 'http://host:1234', - 'schema_set': 'zzz://host', - 'dont_override_schema': 
'http://host', - } - - for conn_id, expected in connection_url_mapping.items(): - with self.subTest(conn_id): - hook = LivyHook(livy_conn_id=conn_id) - - hook.get_conn() - assert hook.base_url == expected - @unittest.skip("inherited HttpHook does not handle missing hostname") def test_missing_host(self): with pytest.raises(AirflowException): @@ -241,35 +207,81 @@ def test_validate_extra_conf(self): with pytest.raises(ValueError): LivyHook._validate_extra_conf({'has_val': 'val', 'no_val': ''}) - @patch('airflow.providers.apache.livy.hooks.livy.LivyHook.run_method') - def test_post_batch_arguments(self, mock_request): + def test_check_session_id(self): + with self.subTest('valid 00'): + try: + LivyHook._validate_session_id(100) + except TypeError: + self.fail("") - mock_request.return_value.status_code = 201 - mock_request.return_value.json.return_value = { - 'id': BATCH_ID, - 'state': BatchState.STARTING.value, - 'log': [], - } + with self.subTest('valid 01'): + try: + LivyHook._validate_session_id(0) + except TypeError: + self.fail("") - hook = LivyHook() - resp = hook.post_batch(file='sparkapp') + with self.subTest('None'): + with pytest.raises(TypeError): + LivyHook._validate_session_id(None) # noqa - mock_request.assert_called_once_with( - method='POST', endpoint='/batches', data=json.dumps({'file': 'sparkapp'}) - ) + with self.subTest('random string'): + with pytest.raises(TypeError): + LivyHook._validate_session_id('asd') + + +@pytest.fixture(scope="class") +def setup_connections(): + db.merge_conn(Connection(conn_id='livy_default', conn_type='http', host='host', schema='http', port=8998)) + db.merge_conn(Connection(conn_id='default_port', conn_type='http', host='http://host')) + db.merge_conn(Connection(conn_id='default_protocol', conn_type='http', host='host')) + db.merge_conn(Connection(conn_id='port_set', host='host', conn_type='http', port=1234)) + db.merge_conn(Connection(conn_id='schema_set', host='host', conn_type='http', schema='zzz')) + 
db.merge_conn( + Connection(conn_id='dont_override_schema', conn_type='http', host='http://host', schema='zzz') + ) + db.merge_conn(Connection(conn_id='missing_host', conn_type='http', port=1234)) + db.merge_conn(Connection(conn_id='invalid_uri', uri='http://invalid_uri:4321')) - request_args = mock_request.call_args[1] - assert 'data' in request_args - assert isinstance(request_args['data'], str) +class TestLivyMockHttpx: + def test_build_get_hook(self, setup_connections): + + connection_url_mapping = { + # id, expected + 'default_port': 'http://host', + 'default_protocol': 'http://host', + 'port_set': 'http://host:1234', + 'schema_set': 'zzz://host', + 'dont_override_schema': 'http://host', + } + + for conn_id, expected in connection_url_mapping.items(): + hook = LivyHook(livy_conn_id=conn_id) + + hook.get_conn() + assert hook.base_url == expected + + def test_post_batch_arguments(self, httpx_mock, setup_connections): + httpx_mock.add_response( + method='POST', + url="http://livy:8998/batches", + status_code=201, + match_content=json.dumps({'file': 'sparkapp'}).encode("utf-8"), + json={ + 'id': BATCH_ID, + 'state': BatchState.STARTING.value, + 'log': [], + }, + ) + hook = LivyHook() + resp = hook.post_batch(file='sparkapp') assert isinstance(resp, int) assert resp == BATCH_ID - @requests_mock.mock() - def test_post_batch_success(self, mock): - mock.register_uri( - 'POST', - '//livy:8998/batches', + def test_post_batch_success(self, httpx_mock, setup_connections): + httpx_mock.add_response( + method='POST', + url="http://livy:8998/batches", json={'id': BATCH_ID, 'state': BatchState.STARTING.value, 'log': []}, status_code=201, ) @@ -279,17 +291,21 @@ def test_post_batch_success(self, mock): assert isinstance(resp, int) assert resp == BATCH_ID - @requests_mock.mock() - def test_post_batch_fail(self, mock): - mock.register_uri('POST', '//livy:8998/batches', json={}, status_code=400, reason='ERROR') - + def test_post_batch_fail(self, httpx_mock, setup_connections): + 
httpx_mock.add_response( + method='POST', + url="http://livy:8998/batches", + json={}, + status_code=400, + ) hook = LivyHook() with pytest.raises(AirflowException): hook.post_batch(file='sparkapp') - @requests_mock.mock() - def test_get_batch_success(self, mock): - mock.register_uri('GET', f'//livy:8998/batches/{BATCH_ID}', json={'id': BATCH_ID}, status_code=200) + def test_get_batch_success(self, httpx_mock, setup_connections): + httpx_mock.add_response( + method='GET', url=f'http://livy:8998/batches/{BATCH_ID}', json={'id': BATCH_ID}, status_code=200 + ) hook = LivyHook() resp = hook.get_batch(BATCH_ID) @@ -297,33 +313,28 @@ def test_get_batch_success(self, mock): assert isinstance(resp, dict) assert 'id' in resp - @requests_mock.mock() - def test_get_batch_fail(self, mock): - mock.register_uri( - 'GET', - f'//livy:8998/batches/{BATCH_ID}', + def test_get_batch_fail(self, httpx_mock, setup_connections): + httpx_mock.add_response( + method='GET', + url=f'http://livy:8998/batches/{BATCH_ID}', json={'msg': 'Unable to find batch'}, status_code=404, - reason='ERROR', ) hook = LivyHook() with pytest.raises(AirflowException): hook.get_batch(BATCH_ID) - def test_invalid_uri(self): + def test_invalid_uri(self, setup_connections): hook = LivyHook(livy_conn_id='invalid_uri') - with pytest.raises(RequestException): + with pytest.raises(httpx.ConnectError): hook.post_batch(file='sparkapp') - @requests_mock.mock() - def test_get_batch_state_success(self, mock): - + def test_get_batch_state_success(self, httpx_mock, setup_connections): running = BatchState.RUNNING - - mock.register_uri( - 'GET', - f'//livy:8998/batches/{BATCH_ID}/state', + httpx_mock.add_response( + method='GET', + url=f'http://livy:8998/batches/{BATCH_ID}/state', json={'id': BATCH_ID, 'state': running.value}, status_code=200, ) @@ -333,116 +344,98 @@ def test_get_batch_state_success(self, mock): assert isinstance(state, BatchState) assert state == running - @requests_mock.mock() - def 
test_get_batch_state_fail(self, mock): - mock.register_uri( - 'GET', f'//livy:8998/batches/{BATCH_ID}/state', json={}, status_code=400, reason='ERROR' + def test_get_batch_state_fail(self, httpx_mock, setup_connections): + httpx_mock.add_response( + method='GET', url=f'http://livy:8998/batches/{BATCH_ID}/state', json={}, status_code=400 ) hook = LivyHook() with pytest.raises(AirflowException): hook.get_batch_state(BATCH_ID) - @requests_mock.mock() - def test_get_batch_state_missing(self, mock): - mock.register_uri('GET', f'//livy:8998/batches/{BATCH_ID}/state', json={}, status_code=200) + def test_get_batch_state_missing(self, httpx_mock, setup_connections): + httpx_mock.add_response( + method='GET', url=f'http://livy:8998/batches/{BATCH_ID}/state', json={}, status_code=200 + ) hook = LivyHook() with pytest.raises(AirflowException): hook.get_batch_state(BATCH_ID) - def test_parse_post_response(self): + def test_parse_post_response(self, setup_connections): res_id = LivyHook._parse_post_response({'id': BATCH_ID, 'log': []}) assert BATCH_ID == res_id - @requests_mock.mock() - def test_delete_batch_success(self, mock): - mock.register_uri( - 'DELETE', f'//livy:8998/batches/{BATCH_ID}', json={'msg': 'deleted'}, status_code=200 + def test_delete_batch_success(self, httpx_mock, setup_connections): + httpx_mock.add_response( + method='DELETE', + url=f'http://livy:8998/batches/{BATCH_ID}', + json={'msg': 'deleted'}, + status_code=200, ) resp = LivyHook().delete_batch(BATCH_ID) assert resp == {'msg': 'deleted'} - @requests_mock.mock() - def test_delete_batch_fail(self, mock): - mock.register_uri( - 'DELETE', f'//livy:8998/batches/{BATCH_ID}', json={}, status_code=400, reason='ERROR' + def test_delete_batch_fail(self, httpx_mock, setup_connections): + httpx_mock.add_response( + method='DELETE', url=f'http://livy:8998/batches/{BATCH_ID}', json={}, status_code=400 ) hook = LivyHook() with pytest.raises(AirflowException): hook.delete_batch(BATCH_ID) - @requests_mock.mock() - 
def test_missing_batch_id(self, mock): - mock.register_uri('POST', '//livy:8998/batches', json={}, status_code=201) - + def test_missing_batch_id(self, httpx_mock, setup_connections): + httpx_mock.add_response(method='POST', url='http://livy:8998/batches', json={}, status_code=201) hook = LivyHook() with pytest.raises(AirflowException): hook.post_batch(file='sparkapp') - @requests_mock.mock() - def test_get_batch_validation(self, mock): - mock.register_uri('GET', f'//livy:8998/batches/{BATCH_ID}', json=SAMPLE_GET_RESPONSE, status_code=200) + def test_get_batch_validation(self, httpx_mock, setup_connections): + httpx_mock.add_response( + method='GET', + url=f'http://livy:8998/batches/{BATCH_ID}', + json=SAMPLE_GET_RESPONSE, + status_code=200, + ) hook = LivyHook() - with self.subTest('get_batch'): - hook.get_batch(BATCH_ID) + hook.get_batch(BATCH_ID) # make sure blocked by validation for val in [None, 'one', {'a': 'b'}]: - with self.subTest(f'get_batch {val}'): - with pytest.raises(TypeError): - hook.get_batch(val) - - @requests_mock.mock() - def test_get_batch_state_validation(self, mock): - mock.register_uri( - 'GET', f'//livy:8998/batches/{BATCH_ID}/state', json=SAMPLE_GET_RESPONSE, status_code=200 + with pytest.raises(TypeError): + hook.get_batch(val) + + def test_get_batch_state_validation(self, httpx_mock, setup_connections): + httpx_mock.add_response( + method='GET', + url=f'http://livy:8998/batches/{BATCH_ID}/state', + json=SAMPLE_GET_RESPONSE, + status_code=200, ) hook = LivyHook() - with self.subTest('get_batch'): - hook.get_batch_state(BATCH_ID) + hook.get_batch_state(BATCH_ID) for val in [None, 'one', {'a': 'b'}]: - with self.subTest(f'get_batch {val}'): - with pytest.raises(TypeError): - hook.get_batch_state(val) + with pytest.raises(TypeError): + hook.get_batch_state(val) - @requests_mock.mock() - def test_delete_batch_validation(self, mock): - mock.register_uri('DELETE', f'//livy:8998/batches/{BATCH_ID}', json={'id': BATCH_ID}, status_code=200) + def 
test_delete_batch_validation(self, httpx_mock, setup_connections): + httpx_mock.add_response( + method='DELETE', + url=f'http://livy:8998/batches/{BATCH_ID}', + json={'id': BATCH_ID}, + status_code=200, + ) hook = LivyHook() - with self.subTest('get_batch'): - hook.delete_batch(BATCH_ID) + hook.delete_batch(BATCH_ID) for val in [None, 'one', {'a': 'b'}]: - with self.subTest(f'get_batch {val}'): - with pytest.raises(TypeError): - hook.delete_batch(val) - - def test_check_session_id(self): - with self.subTest('valid 00'): - try: - LivyHook._validate_session_id(100) - except TypeError: - self.fail("") - - with self.subTest('valid 01'): - try: - LivyHook._validate_session_id(0) - except TypeError: - self.fail("") - - with self.subTest('None'): with pytest.raises(TypeError): - LivyHook._validate_session_id(None) # noqa - - with self.subTest('random string'): - with pytest.raises(TypeError): - LivyHook._validate_session_id('asd') + hook.delete_batch(val) diff --git a/tests/providers/http/hooks/test_http.py b/tests/providers/http/hooks/test_http.py index 825847b982156..3a0a20714207e 100644 --- a/tests/providers/http/hooks/test_http.py +++ b/tests/providers/http/hooks/test_http.py @@ -17,15 +17,11 @@ # under the License. 
import json import os -import unittest -from collections import OrderedDict from unittest import mock +import httpx import pytest -import requests -import requests_mock import tenacity -from parameterized import parameterized from airflow.exceptions import AirflowException from airflow.models import Connection @@ -40,41 +36,42 @@ def get_airflow_connection_with_port(unused_conn_id=None): return Connection(conn_id='http_default', conn_type='http', host='test.com', port=1234) -class TestHttpHook(unittest.TestCase): - """Test get, post and raise_for_status""" +@pytest.fixture +def setup_hook(): + yield HttpHook(method='GET') + + +@pytest.fixture +def setup_lowercase_hook(): + yield HttpHook(method='get') - def setUp(self): - session = requests.Session() - adapter = requests_mock.Adapter() - session.mount('mock', adapter) - self.get_hook = HttpHook(method='GET') - self.get_lowercase_hook = HttpHook(method='get') - self.post_hook = HttpHook(method='POST') - @requests_mock.mock() - def test_raise_for_status_with_200(self, m): +@pytest.fixture +def setup_post_hook(): + yield HttpHook(method='POST') + + +class TestHttpHook: + """Test get, post and raise_for_status""" - m.get('http://test:8080/v1/test', status_code=200, text='{"status":{"status": 200}}', reason='OK') + def test_raise_for_status_with_200(self, httpx_mock, setup_hook): + httpx_mock.add_response( + url='http://test:8080/v1/test', status_code=200, data='{"status":{"status": 200}}' + ) with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection): - resp = self.get_hook.run('v1/test') + resp = setup_hook.run('v1/test') assert resp.text == '{"status":{"status": 200}}' - @requests_mock.mock() - @mock.patch('requests.Session') - @mock.patch('requests.Request') - def test_get_request_with_port(self, mock_requests, request_mock, mock_session): - from requests.exceptions import MissingSchema - + @mock.patch('httpx.Client') + @mock.patch('httpx.Request') + def 
test_get_request_with_port(self, request_mock, mock_client, setup_hook): with mock.patch( 'airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection_with_port ): expected_url = 'http://test.com:1234/some/endpoint' for endpoint in ['some/endpoint', '/some/endpoint']: - try: - self.get_hook.run(endpoint) - except MissingSchema: - pass + setup_hook.run(endpoint) request_mock.assert_called_once_with( mock.ANY, expected_url, headers=mock.ANY, params=mock.ANY @@ -82,117 +79,116 @@ def test_get_request_with_port(self, mock_requests, request_mock, mock_session): request_mock.reset_mock() - @requests_mock.mock() - def test_get_request_do_not_raise_for_status_if_check_response_is_false(self, m): - - m.get( - 'http://test:8080/v1/test', + def test_get_request_do_not_raise_for_status_if_check_response_is_false(self, httpx_mock, setup_hook): + httpx_mock.add_response( + url='http://test:8080/v1/test', status_code=404, - text='{"status":{"status": 404}}', - reason='Bad request', + data='{"status":{"status": 404}}', ) with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection): - resp = self.get_hook.run('v1/test', extra_options={'check_response': False}) + resp = setup_hook.run('v1/test', extra_options={'check_response': False}) assert resp.text == '{"status":{"status": 404}}' - @requests_mock.mock() - def test_hook_contains_header_from_extra_field(self, mock_requests): + def test_hook_contains_header_from_extra_field(self, setup_hook): with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection): expected_conn = get_airflow_connection() - conn = self.get_hook.get_conn() + conn = setup_hook.get_conn() assert dict(conn.headers, **json.loads(expected_conn.extra)) == conn.headers assert conn.headers.get('bareer') == 'test' - @requests_mock.mock() - @mock.patch('requests.Request') - def test_hook_with_method_in_lowercase(self, mock_requests, request_mock): - from requests.exceptions 
import InvalidURL, MissingSchema - + def test_hook_with_method_in_lowercase(self, httpx_mock, setup_lowercase_hook): + data = "test_params=aaaa" + httpx_mock.add_response( + url='http://test.com:1234/v1/test?test_params=aaaa', + status_code=200, + data='{"status":{"status": 200}}', + ) with mock.patch( 'airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection_with_port ): - data = "test params" try: - self.get_lowercase_hook.run('v1/test', data=data) - except (MissingSchema, InvalidURL): + setup_lowercase_hook.run('v1/test', data=data) + except ConnectionError: pass - request_mock.assert_called_once_with(mock.ANY, mock.ANY, headers=mock.ANY, params=data) - @requests_mock.mock() - def test_hook_uses_provided_header(self, mock_requests): - conn = self.get_hook.get_conn(headers={"bareer": "newT0k3n"}) + def test_hook_uses_provided_header(self, setup_hook): + conn = setup_hook.get_conn(headers={"bareer": "newT0k3n"}) assert conn.headers.get('bareer') == "newT0k3n" - @requests_mock.mock() - def test_hook_has_no_header_from_extra(self, mock_requests): - conn = self.get_hook.get_conn() + def test_hook_has_no_header_from_extra(self, setup_hook): + conn = setup_hook.get_conn() assert conn.headers.get('bareer') is None - @requests_mock.mock() - def test_hooks_header_from_extra_is_overridden(self, mock_requests): + def test_hooks_header_from_extra_is_overridden(self, setup_hook): with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection): - conn = self.get_hook.get_conn(headers={"bareer": "newT0k3n"}) - assert conn.headers.get('bareer') == 'newT0k3n' + conn = setup_hook.get_conn() + assert conn.headers.get('bareer') == 'test' - @requests_mock.mock() - def test_post_request(self, mock_requests): - mock_requests.post( - 'http://test:8080/v1/test', status_code=200, text='{"status":{"status": 200}}', reason='OK' + def test_hooks_header_from_extra_is_overridden_and_used(self, httpx_mock, setup_hook): + 
httpx_mock.add_response( + url='http://test:8080/v1/test', + status_code=200, + data='{"status":{"status": 200}}', + match_headers={"bareer": "test"}, + ) + with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection): + setup_hook.run('v1/test', extra_options={'check_response': False}) + + def test_post_request(self, httpx_mock, setup_post_hook): + httpx_mock.add_response( + method='POST', url='http://test:8080/v1/test', status_code=200, data='{"status":{"status": 200}}' ) with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection): - resp = self.post_hook.run('v1/test') + resp = setup_post_hook.run('v1/test') assert resp.status_code == 200 - @requests_mock.mock() - def test_post_request_with_error_code(self, mock_requests): - mock_requests.post( - 'http://test:8080/v1/test', + def test_post_request_with_error_code(self, httpx_mock, setup_post_hook): + httpx_mock.add_response( + method='POST', + url='http://test:8080/v1/test', status_code=418, - text='{"status":{"status": 418}}', - reason='I\'m a teapot', ) with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection): with pytest.raises(AirflowException): - self.post_hook.run('v1/test') - - @requests_mock.mock() - def test_post_request_do_not_raise_for_status_if_check_response_is_false(self, mock_requests): - mock_requests.post( - 'http://test:8080/v1/test', + setup_post_hook.run('v1/test') + + def test_post_request_do_not_raise_for_status_if_check_response_is_false( + self, httpx_mock, setup_post_hook + ): + httpx_mock.add_response( + method='POST', + url='http://test:8080/v1/test', status_code=418, - text='{"status":{"status": 418}}', - reason='I\'m a teapot', ) with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection): - resp = self.post_hook.run('v1/test', extra_options={'check_response': False}) + resp = setup_post_hook.run('v1/test', 
extra_options={'check_response': False}) assert resp.status_code == 418 - @mock.patch('airflow.providers.http.hooks.http.requests.Session') - def test_retry_on_conn_error(self, mocked_session): + @mock.patch('airflow.providers.http.hooks.http.httpx.Client') + def test_retry_on_conn_error(self, mocked_client, setup_hook): retry_args = dict( wait=tenacity.wait_none(), stop=tenacity.stop_after_attempt(7), - retry=tenacity.retry_if_exception_type(requests.exceptions.ConnectionError), + retry=tenacity.retry_if_exception_type(httpx.NetworkError), ) def send_and_raise(unused_request, **kwargs): - raise requests.exceptions.ConnectionError + raise httpx.NetworkError(message="ConnectionError") - mocked_session().send.side_effect = send_and_raise + mocked_client().send.side_effect = send_and_raise # The job failed for some reason with pytest.raises(tenacity.RetryError): - self.get_hook.run_with_advanced_retry(endpoint='v1/test', _retry_args=retry_args) - assert self.get_hook._retry_obj.stop.max_attempt_number + 1 == mocked_session.call_count + setup_hook.run_with_advanced_retry(endpoint='v1/test', _retry_args=retry_args) + assert setup_hook._retry_obj.stop.max_attempt_number + 1 == mocked_client.call_count - @requests_mock.mock() - def test_run_with_advanced_retry(self, m): + def test_run_with_advanced_retry(self, httpx_mock, setup_hook): - m.get('http://test:8080/v1/test', status_code=200, reason='OK') + httpx_mock.add_response(url='http://test:8080/v1/test', status_code=200) retry_args = dict( wait=tenacity.wait_none(), @@ -201,22 +197,18 @@ def test_run_with_advanced_retry(self, m): reraise=True, ) with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection): - response = self.get_hook.run_with_advanced_retry(endpoint='v1/test', _retry_args=retry_args) - assert isinstance(response, requests.Response) - - def test_header_from_extra_and_run_method_are_merged(self): - def run_and_return(unused_session, prepped_request, 
unused_extra_options, **kwargs): - return prepped_request - + response = setup_hook.run_with_advanced_retry(endpoint='v1/test', _retry_args=retry_args) + assert isinstance(response, httpx.Response) + + def test_header_from_extra_and_run_method_are_merged(self, httpx_mock, setup_hook): + httpx_mock.add_response( + url='http://test:8080/v1/test', + status_code=200, + match_headers={"bareer": "test", "some_other_header": "test"}, + ) # The job failed for some reason - with mock.patch( - 'airflow.providers.http.hooks.http.HttpHook.run_and_check', side_effect=run_and_return - ): - with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection): - prepared_request = self.get_hook.run('v1/test', headers={'some_other_header': 'test'}) - actual = dict(prepared_request.headers) - assert actual.get('bareer') == 'test' - assert actual.get('some_other_header') == 'test' + with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection): + setup_hook.run('v1/test', headers={'some_other_header': 'test'}) @mock.patch('airflow.providers.http.hooks.http.HttpHook.get_connection') def test_http_connection(self, mock_get_connection): @@ -250,8 +242,8 @@ def test_host_encoded_https_connection(self, mock_get_connection): hook.get_conn({}) assert hook.base_url == 'https://localhost' - def test_method_converted_to_uppercase_when_created_in_lowercase(self): - assert self.get_lowercase_hook.method == 'GET' + def test_method_converted_to_uppercase_when_created_in_lowercase(self, setup_lowercase_hook): + assert setup_lowercase_hook.method == 'GET' @mock.patch('airflow.providers.http.hooks.http.HttpHook.get_connection') def test_connection_without_host(self, mock_get_connection): @@ -262,96 +254,90 @@ def test_connection_without_host(self, mock_get_connection): hook.get_conn({}) assert hook.base_url == 'http://' - @parameterized.expand( - [ - 'GET', - 'POST', - ] - ) - @requests_mock.mock() - def test_json_request(self, method, 
mock_requests): + def test_json_request_get(self, httpx_mock): obj1 = {'a': 1, 'b': 'abc', 'c': [1, 2, {"d": 10}]} - def match_obj1(request): - return request.json() == obj1 + httpx_mock.add_response( + method='GET', url='http://test:8080/v1/test', match_content=json.dumps(obj1).encode('utf-8') + ) - mock_requests.request(method=method, url='//test:8080/v1/test', additional_matcher=match_obj1) + with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection): + # httpx_mock raises httpx.TimeoutException (no matching response) if the request body != obj1 + HttpHook(method='GET').run('v1/test', json=obj1) + + def test_json_request_post(self, httpx_mock): + obj1 = {'a': 1, 'b': 'abc', 'c': [1, 2, {"d": 10}]} + + httpx_mock.add_response( + method='POST', url='http://test:8080/v1/test', match_content=json.dumps(obj1).encode('utf-8') + ) with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection): # will raise NoMockAddress exception if obj1 != request.json() - HttpHook(method=method).run('v1/test', json=obj1) + HttpHook(method='POST').run('v1/test', json=obj1) - @mock.patch('airflow.providers.http.hooks.http.requests.Session.send') - def test_verify_set_to_true_by_default(self, mock_session_send): + @mock.patch('airflow.providers.http.hooks.http.httpx.Client') + def test_verify_set_to_true_by_default(self, mock_client, setup_hook): with mock.patch( 'airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection_with_port ): - self.get_hook.run('/some/endpoint') - mock_session_send.assert_called_once_with( - mock.ANY, - allow_redirects=True, - cert=None, - proxies=OrderedDict(), - stream=False, - timeout=None, - verify=True, + setup_hook.run('/some/endpoint') + mock_client.assert_called_once_with(verify=True, cert=None, proxies=None) + + @mock.patch('airflow.providers.http.hooks.http.httpx.Client') + def test_verify_can_be_overridden(self, mock_client, setup_hook): + with 
mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection): + setup_hook.run('/some/endpoint', extra_options={'verify': False}) + mock_client.assert_called_once_with(verify=False, cert=None, proxies=None) + + @mock.patch('airflow.providers.http.hooks.http.httpx.Client') + def test_cert_can_be_overridden(self, mock_client, setup_hook): + with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection): + setup_hook.run('/some/endpoint', extra_options={'cert': '/tmp/private.crt'}) + mock_client.assert_called_once_with(verify=True, cert='/tmp/private.crt', proxies=None) + + @mock.patch('airflow.providers.http.hooks.http.httpx.Client') + def test_proxies_can_be_overridden(self, mock_client, setup_hook): + with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection): + setup_hook.run('/some/endpoint', extra_options={'proxies': {"http://localhost": 'http://proxy'}}) + mock_client.assert_called_once_with( + verify=True, cert=None, proxies={"http://localhost": 'http://proxy'} ) - @mock.patch('airflow.providers.http.hooks.http.requests.Session.send') - @mock.patch.dict(os.environ, {"REQUESTS_CA_BUNDLE": "/tmp/test.crt"}) - def test_requests_ca_bundle_env_var(self, mock_session_send): + @mock.patch('airflow.providers.http.hooks.http.httpx.Client') + def test_verify_parameter_set(self, mock_client, setup_hook): with mock.patch( 'airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection_with_port ): - self.get_hook.run('/some/endpoint') - - mock_session_send.assert_called_once_with( - mock.ANY, - allow_redirects=True, - cert=None, - proxies=OrderedDict(), - stream=False, - timeout=None, - verify='/tmp/test.crt', - ) + setup_hook.run('/some/endpoint', extra_options={'verify': '/tmp/overridden.crt'}) + mock_client.assert_called_once_with(cert=None, proxies=None, verify='/tmp/overridden.crt') - 
@mock.patch('airflow.providers.http.hooks.http.requests.Session.send') + @mock.patch('airflow.providers.http.hooks.http.httpx.Client') @mock.patch.dict(os.environ, {"REQUESTS_CA_BUNDLE": "/tmp/test.crt"}) - def test_verify_respects_requests_ca_bundle_env_var(self, mock_session_send): + def test_requests_ca_bundle_env_var(self, mock_client, setup_hook): with mock.patch( 'airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection_with_port ): - self.get_hook.run('/some/endpoint', extra_options={'verify': True}) - - mock_session_send.assert_called_once_with( - mock.ANY, - allow_redirects=True, - cert=None, - proxies=OrderedDict(), - stream=False, - timeout=None, - verify='/tmp/test.crt', - ) + setup_hook.run('/some/endpoint') + mock_client.assert_called_once_with(cert=None, proxies=None, verify='/tmp/test.crt') - @mock.patch('airflow.providers.http.hooks.http.requests.Session.send') + @mock.patch('airflow.providers.http.hooks.http.httpx.Client') @mock.patch.dict(os.environ, {"REQUESTS_CA_BUNDLE": "/tmp/test.crt"}) - def test_verify_false_parameter_overwrites_set_requests_ca_bundle_env_var(self, mock_session_send): + def test_verify_respects_requests_ca_bundle_env_var(self, mock_client, setup_hook): with mock.patch( 'airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection_with_port ): - self.get_hook.run('/some/endpoint', extra_options={'verify': False}) - - mock_session_send.assert_called_once_with( - mock.ANY, - allow_redirects=True, - cert=None, - proxies=OrderedDict(), - stream=False, - timeout=None, - verify=False, - ) + setup_hook.run('/some/endpoint', extra_options={'verify': True}) + mock_client.assert_called_once_with(cert=None, proxies=None, verify='/tmp/test.crt') -send_email_test = mock.Mock() + @mock.patch('airflow.providers.http.hooks.http.httpx.Client') + @mock.patch.dict(os.environ, {"REQUESTS_CA_BUNDLE": "/tmp/test.crt"}) + def test_verify_false_parameter_overwrites_set_requests_ca_bundle_env_var(self, 
mock_client, setup_hook): + with mock.patch( + 'airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection_with_port + ): + setup_hook.run('/some/endpoint', extra_options={'verify': False}) + mock_client.assert_called_once_with(cert=None, proxies=None, verify=False) diff --git a/tests/providers/http/operators/test_http.py b/tests/providers/http/operators/test_http.py index d4c622c29e6a5..dc62c0371c740 100644 --- a/tests/providers/http/operators/test_http.py +++ b/tests/providers/http/operators/test_http.py @@ -16,26 +16,23 @@ # specific language governing permissions and limitations # under the License. -import unittest from unittest import mock import pytest -import requests_mock from airflow.exceptions import AirflowException from airflow.providers.http.operators.http import SimpleHttpOperator @mock.patch.dict('os.environ', AIRFLOW_CONN_HTTP_EXAMPLE='http://www.example.com') -class TestSimpleHttpOp(unittest.TestCase): - @requests_mock.mock() - def test_response_in_logs(self, m): +class TestSimpleHttpOp: + def test_response_in_logs(self, httpx_mock): """ Test that when using SimpleHttpOperator with 'GET', the log contains 'Example Domain' in it """ - m.get('http://www.example.com', text='Example.com fake response') + httpx_mock.add_response(url='http://www.example.com', data='Example.com fake response') operator = SimpleHttpOperator( task_id='test_HTTP_op', method='GET', @@ -49,8 +46,7 @@ def test_response_in_logs(self, m): calls = [mock.call('Example.com fake response'), mock.call('Example.com fake response')] mock_info.has_calls(calls) - @requests_mock.mock() - def test_response_in_logs_after_failed_check(self, m): + def test_response_in_logs_after_failed_check(self, httpx_mock): """ Test that when using SimpleHttpOperator with log_response=True, the response is logged even if request_check fails @@ -59,7 +55,7 @@ def test_response_in_logs_after_failed_check(self, m): def response_check(response): return response.text != 'invalid response' - 
m.get('http://www.example.com', text='invalid response') + httpx_mock.add_response(url='http://www.example.com', data='invalid response') operator = SimpleHttpOperator( task_id='test_HTTP_op', method='GET', @@ -75,9 +71,8 @@ def response_check(response): calls = [mock.call('Calling HTTP method'), mock.call('invalid response')] mock_info.assert_has_calls(calls, any_order=True) - @requests_mock.mock() - def test_filters_response(self, m): - m.get('http://www.example.com', json={'value': 5}) + def test_filters_response(self, httpx_mock): + httpx_mock.add_response(url='http://www.example.com', json={'value': 5}) operator = SimpleHttpOperator( task_id='test_HTTP_op', method='GET', diff --git a/tests/providers/http/sensors/test_http.py b/tests/providers/http/sensors/test_http.py index 23ac2fdf2f69f..595deab718cda 100644 --- a/tests/providers/http/sensors/test_http.py +++ b/tests/providers/http/sensors/test_http.py @@ -15,12 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import unittest from unittest import mock -from unittest.mock import patch import pytest -import requests from airflow.exceptions import AirflowException, AirflowSensorTimeout from airflow.models import TaskInstance @@ -34,23 +31,22 @@ TEST_DAG_ID = 'unit_test_dag' -class TestHttpSensor(unittest.TestCase): - def setUp(self): - args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} - self.dag = DAG(TEST_DAG_ID, default_args=args) +@pytest.fixture +def setup_dag(): + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} + yield DAG(TEST_DAG_ID, default_args=args) - @patch("airflow.providers.http.hooks.http.requests.Session.send") - def test_poke_exception(self, mock_session_send): + +class TestHttpSensor: + def test_poke_exception(self, httpx_mock): """ Exception occurs in poke function should not be ignored. 
""" - response = requests.Response() - response.status_code = 200 - mock_session_send.return_value = response def resp_check(_): raise AirflowException('AirflowException raised here!') + httpx_mock.add_response(status_code=200) task = HttpSensor( task_id='http_sensor_poke_exception', http_conn_id='http_default', @@ -63,19 +59,14 @@ def resp_check(_): with pytest.raises(AirflowException, match='AirflowException raised here!'): task.execute(context={}) - @patch("airflow.providers.http.hooks.http.requests.Session.send") - def test_poke_continues_for_http_500_with_extra_options_check_response_false(self, mock_session_send): + def test_poke_continues_for_http_500_with_extra_options_check_response_false(self, httpx_mock, setup_dag): def resp_check(_): return False - response = requests.Response() - response.status_code = 500 - response.reason = 'Internal Server Error' - response._content = b'Internal Server Error' - mock_session_send.return_value = response + httpx_mock.add_response(status_code=500, data="Internal Server Error") task = HttpSensor( - dag=self.dag, + dag=setup_dag, task_id='http_sensor_poke_for_code_500', http_conn_id='http_default', endpoint='', @@ -87,16 +78,16 @@ def resp_check(_): poke_interval=1, ) - with self.assertRaises(AirflowSensorTimeout): + with pytest.raises(AirflowSensorTimeout, match='Snap. Time is OUT. 
DAG id: unit_test_dag'): task.execute(context={}) - @patch("airflow.providers.http.hooks.http.requests.Session.send") - def test_head_method(self, mock_session_send): + def test_head_method(self, httpx_mock, setup_dag): def resp_check(_): return True + httpx_mock.add_response(status_code=200, url='https://www.httpbin.org', method='HEAD') task = HttpSensor( - dag=self.dag, + dag=setup_dag, task_id='http_sensor_head_method', http_conn_id='http_default', endpoint='', @@ -109,19 +100,8 @@ def resp_check(_): task.execute(context={}) - args, kwargs = mock_session_send.call_args - received_request = args[0] - - prep_request = requests.Request('HEAD', 'https://www.httpbin.org', {}).prepare() - - assert prep_request.url == received_request.url - assert prep_request.method, received_request.method - - @patch("airflow.providers.http.hooks.http.requests.Session.send") - def test_poke_context(self, mock_session_send): - response = requests.Response() - response.status_code = 200 - mock_session_send.return_value = response + def test_poke_context(self, httpx_mock, setup_dag): + httpx_mock.add_response(status_code=200, url='https://www.httpbin.org') def resp_check(_, execution_date): if execution_date == DEFAULT_DATE: @@ -136,25 +116,19 @@ def resp_check(_, execution_date): response_check=resp_check, timeout=5, poke_interval=1, - dag=self.dag, + dag=setup_dag, ) task_instance = TaskInstance(task=task, execution_date=DEFAULT_DATE) task.execute(task_instance.get_template_context()) - @patch("airflow.providers.http.hooks.http.requests.Session.send") - def test_logging_head_error_request(self, mock_session_send): + def test_logging_head_error_request(self, httpx_mock, setup_dag): def resp_check(_): return True - response = requests.Response() - response.status_code = 404 - response.reason = 'Not Found' - response._content = b"This endpoint doesn't exist" - mock_session_send.return_value = response - + httpx_mock.add_response(status_code=404, data="This endpoint doesn't exist") task = 
HttpSensor( - dag=self.dag, + dag=setup_dag, task_id='http_sensor_head_method', http_conn_id='http_default', endpoint='', @@ -167,7 +141,7 @@ def resp_check(_): with mock.patch.object(task.hook.log, 'error') as mock_errors: with pytest.raises(AirflowSensorTimeout): - task.execute(None) + task.execute(context={}) assert mock_errors.called calls = [ @@ -187,68 +161,51 @@ def resp_check(_): mock_errors.assert_has_calls(calls) -class FakeSession: - def __init__(self): - self.response = requests.Response() - self.response.status_code = 200 - self.response._content = 'apache/airflow'.encode('ascii', 'ignore') - - def send(self, *args, **kwargs): - return self.response +@pytest.fixture +def setup_op_dag(): + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE_ISO} + yield DAG(TEST_DAG_ID, default_args=args) - def prepare_request(self, request): - if 'date' in request.params: - self.response._content += ('/' + request.params['date']).encode('ascii', 'ignore') - return self.response - def merge_environment_settings(self, _url, **kwargs): - return kwargs - - -class TestHttpOpSensor(unittest.TestCase): - def setUp(self): - args = {'owner': 'airflow', 'start_date': DEFAULT_DATE_ISO} - dag = DAG(TEST_DAG_ID, default_args=args) - self.dag = dag - - @mock.patch('requests.Session', FakeSession) - def test_get(self): +class TestHttpOpSensor: + def test_get(self, httpx_mock, setup_op_dag): + httpx_mock.add_response(status_code=200) op = SimpleHttpOperator( task_id='get_op', method='GET', endpoint='/search', data={"client": "ubuntu", "q": "airflow"}, headers={}, - dag=self.dag, + dag=setup_op_dag, ) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - @mock.patch('requests.Session', FakeSession) - def test_get_response_check(self): + def test_get_response_check(self, httpx_mock, setup_op_dag): + httpx_mock.add_response(status_code=200, data="apache/airflow") op = SimpleHttpOperator( task_id='get_op', method='GET', endpoint='/search', - data={"client": 
"ubuntu", "q": "airflow"}, response_check=lambda response: ("apache/airflow" in response.text), headers={}, - dag=self.dag, + dag=setup_op_dag, ) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - @mock.patch('requests.Session', FakeSession) - def test_sensor(self): + def test_sensor(self, httpx_mock, setup_op_dag): + httpx_mock.add_response( + status_code=200, + url="https://www.httpbin.org//search?client=ubuntu&q=airflow&date=" + + DEFAULT_DATE.strftime('%Y-%m-%d'), + ) sensor = HttpSensor( task_id='http_sensor_check', http_conn_id='http_default', endpoint='/search', request_params={"client": "ubuntu", "q": "airflow", 'date': '{{ds}}'}, headers={}, - response_check=lambda response: ( - "apache/airflow/" + DEFAULT_DATE.strftime('%Y-%m-%d') in response.text - ), poke_interval=5, timeout=15, - dag=self.dag, + dag=setup_op_dag, ) sensor.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) diff --git a/tests/providers/opsgenie/hooks/test_opsgenie_alert.py b/tests/providers/opsgenie/hooks/test_opsgenie_alert.py index 0db9ca4cdbdce..04743c93cd5db 100644 --- a/tests/providers/opsgenie/hooks/test_opsgenie_alert.py +++ b/tests/providers/opsgenie/hooks/test_opsgenie_alert.py @@ -20,7 +20,6 @@ import unittest import pytest -import requests_mock from airflow.exceptions import AirflowException from airflow.models import Connection @@ -31,6 +30,36 @@ class TestOpsgenieAlertHook(unittest.TestCase): conn_id = 'opsgenie_conn_id_test' opsgenie_alert_endpoint = 'https://api.opsgenie.com/v2/alerts' + + def setUp(self): + db.merge_conn( + Connection( + conn_id=self.conn_id, + conn_type='http', + host='https://api.opsgenie.com/', + password='eb243592-faa2-4ba2-a551q-1afdf565c889', + ) + ) + + def test_get_api_key(self): + hook = OpsgenieAlertHook(opsgenie_conn_id=self.conn_id) + api_key = hook._get_api_key() + assert 'eb243592-faa2-4ba2-a551q-1afdf565c889' == api_key + + def test_get_conn_defaults_host(self): + hook = OpsgenieAlertHook() 
+ hook.get_conn() + assert 'https://api.opsgenie.com' == hook.base_url + + +class TestOpsGenieAlertMockHttpx: + conn_id = 'opsgenie_conn_id_test' + opsgenie_alert_endpoint = 'https://api.opsgenie.com/v2/alerts' + _mock_success_response_body = { + "result": "Request will be processed", + "took": 0.302, + "requestId": "43a29c5c-3dbf-4fa4-9c26-f4f71023e120", + } _payload = { 'message': 'An example alert message', 'alias': 'Life is too short for no alias', @@ -60,57 +89,35 @@ class TestOpsgenieAlertHook(unittest.TestCase): 'user': 'Jesse', 'note': 'Write this down', } - _mock_success_response_body = { - "result": "Request will be processed", - "took": 0.302, - "requestId": "43a29c5c-3dbf-4fa4-9c26-f4f71023e120", - } - - def setUp(self): - db.merge_conn( - Connection( - conn_id=self.conn_id, - conn_type='http', - host='https://api.opsgenie.com/', - password='eb243592-faa2-4ba2-a551q-1afdf565c889', - ) - ) - def test_get_api_key(self): - hook = OpsgenieAlertHook(opsgenie_conn_id=self.conn_id) - api_key = hook._get_api_key() - assert 'eb243592-faa2-4ba2-a551q-1afdf565c889' == api_key - - def test_get_conn_defaults_host(self): - hook = OpsgenieAlertHook() - hook.get_conn() - assert 'https://api.opsgenie.com' == hook.base_url - - @requests_mock.mock() - def test_call_with_success(self, m): + def test_call_with_success(self, httpx_mock): hook = OpsgenieAlertHook(opsgenie_conn_id=self.conn_id) - m.post(self.opsgenie_alert_endpoint, status_code=202, json=self._mock_success_response_body) + httpx_mock.add_response( + url=self.opsgenie_alert_endpoint, status_code=202, json=self._mock_success_response_body + ) resp = hook.execute(payload=self._payload) assert resp.status_code == 202 assert resp.json() == self._mock_success_response_body - @requests_mock.mock() - def test_api_key_set(self, m): + def test_api_key_set(self, httpx_mock): hook = OpsgenieAlertHook(opsgenie_conn_id=self.conn_id) - m.post(self.opsgenie_alert_endpoint, status_code=202, 
json=self._mock_success_response_body) + httpx_mock.add_response( + url=self.opsgenie_alert_endpoint, status_code=202, json=self._mock_success_response_body + ) resp = hook.execute(payload=self._payload) assert resp.request.headers.get('Authorization') == 'GenieKey eb243592-faa2-4ba2-a551q-1afdf565c889' - @requests_mock.mock() - def test_api_key_not_set(self, m): + def test_api_key_not_set(self): hook = OpsgenieAlertHook() - m.post(self.opsgenie_alert_endpoint, status_code=202, json=self._mock_success_response_body) with pytest.raises(AirflowException): hook.execute(payload=self._payload) - @requests_mock.mock() - def test_payload_set(self, m): + def test_payload_set(self, httpx_mock): hook = OpsgenieAlertHook(opsgenie_conn_id=self.conn_id) - m.post(self.opsgenie_alert_endpoint, status_code=202, json=self._mock_success_response_body) - resp = hook.execute(payload=self._payload) - assert json.loads(resp.request.body) == self._payload + httpx_mock.add_response( + url=self.opsgenie_alert_endpoint, + status_code=202, + json=self._mock_success_response_body, + match_content=json.dumps(self._payload).encode("utf-8"), + ) + hook.execute(payload=self._payload) diff --git a/tests/providers/slack/hooks/test_slack_webhook.py b/tests/providers/slack/hooks/test_slack_webhook.py index 6fce527a76e23..92603b4388a69 100644 --- a/tests/providers/slack/hooks/test_slack_webhook.py +++ b/tests/providers/slack/hooks/test_slack_webhook.py @@ -18,9 +18,6 @@ # import json import unittest -from unittest import mock - -from requests.exceptions import MissingSchema from airflow.models import Connection from airflow.providers.slack.hooks.slack_webhook import SlackWebhookHook @@ -129,41 +126,22 @@ def test_build_slack_message(self): # Then assert self.expected_message_dict == json.loads(message) - @mock.patch('requests.Session') - @mock.patch('requests.Request') - def test_url_generated_by_http_conn_id(self, mock_request, mock_session): + +class TestSlackWebhookMockHttpx: + expected_url = 
'https://hooks.slack.com/services/T000/B000/XXX' + expected_method = 'POST' + + def test_url_generated_by_http_conn_id(self, httpx_mock): + httpx_mock.add_response(url=self.expected_url, method=self.expected_method) hook = SlackWebhookHook(http_conn_id='slack-webhook-url') - try: - hook.execute() - except MissingSchema: - pass - mock_request.assert_called_once_with( - self.expected_method, self.expected_url, headers=mock.ANY, data=mock.ANY - ) - mock_request.reset_mock() + hook.execute() - @mock.patch('requests.Session') - @mock.patch('requests.Request') - def test_url_generated_by_endpoint(self, mock_request, mock_session): + def test_url_generated_by_endpoint(self, httpx_mock): + httpx_mock.add_response(url=self.expected_url, method=self.expected_method) hook = SlackWebhookHook(webhook_token=self.expected_url) - try: - hook.execute() - except MissingSchema: - pass - mock_request.assert_called_once_with( - self.expected_method, self.expected_url, headers=mock.ANY, data=mock.ANY - ) - mock_request.reset_mock() + hook.execute() - @mock.patch('requests.Session') - @mock.patch('requests.Request') - def test_url_generated_by_http_conn_id_and_endpoint(self, mock_request, mock_session): + def test_url_generated_by_http_conn_id_and_endpoint(self, httpx_mock): + httpx_mock.add_response(url=self.expected_url, method=self.expected_method) hook = SlackWebhookHook(http_conn_id='slack-webhook-host', webhook_token='B000/XXX') - try: - hook.execute() - except MissingSchema: - pass - mock_request.assert_called_once_with( - self.expected_method, self.expected_url, headers=mock.ANY, data=mock.ANY - ) - mock_request.reset_mock() + hook.execute()