diff --git a/airflow/providers/airbyte/CHANGELOG.rst b/airflow/providers/airbyte/CHANGELOG.rst
index cef7dda80708a..24491fb7decc0 100644
--- a/airflow/providers/airbyte/CHANGELOG.rst
+++ b/airflow/providers/airbyte/CHANGELOG.rst
@@ -19,6 +19,17 @@
Changelog
---------
+2.0.0
+.....
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+The Airbyte Hook derives now from the new version of Hook in Http Provider.
+
+TODO: Add extra dependency to 2.0.0 version of http Provider.
+
+
1.0.0
.....
diff --git a/airflow/providers/airbyte/provider.yaml b/airflow/providers/airbyte/provider.yaml
index 77b109f45058d..08b44be18c815 100644
--- a/airflow/providers/airbyte/provider.yaml
+++ b/airflow/providers/airbyte/provider.yaml
@@ -22,6 +22,7 @@ description: |
`Airbyte `__
versions:
+ - 2.0.0
- 1.0.0
integrations:
diff --git a/airflow/providers/apache/livy/CHANGELOG.rst b/airflow/providers/apache/livy/CHANGELOG.rst
index d43fe14754f4c..aaa65d7d29c17 100644
--- a/airflow/providers/apache/livy/CHANGELOG.rst
+++ b/airflow/providers/apache/livy/CHANGELOG.rst
@@ -19,6 +19,14 @@
Changelog
---------
+2.0.0
+.....
+
+The Livy Hook derives now from the new version of Http Hook in Http Provider.
+
+TODO: Add extra dependency to 2.0.0 version of http Provider.
+
+
1.1.0
.....
diff --git a/airflow/providers/apache/livy/hooks/livy.py b/airflow/providers/apache/livy/hooks/livy.py
index 75d08af56883f..0c4d6e972da4d 100644
--- a/airflow/providers/apache/livy/hooks/livy.py
+++ b/airflow/providers/apache/livy/hooks/livy.py
@@ -21,7 +21,7 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Sequence, Union
-import requests
+import httpx
from airflow.exceptions import AirflowException
from airflow.providers.http.hooks.http import HttpHook
@@ -75,19 +75,25 @@ def __init__(
super().__init__(http_conn_id=livy_conn_id)
self.extra_options = extra_options or {}
- def get_conn(self, headers: Optional[Dict[str, Any]] = None) -> Any:
+ def get_conn(
+ self, headers: Optional[Dict[Any, Any]] = None, verify: bool = True, proxies=None, cert=None
+ ) -> httpx.Client:
"""
Returns http session for use with requests
:param headers: additional headers to be passed through as a dictionary
:type headers: dict
- :return: requests session
- :rtype: requests.Session
+ :param verify: whether to verify SSL during the connection (only use for testing)
+ :param proxies: A dictionary mapping proxy keys to proxy
+ :param cert: client An SSL certificate used by the requested host
+ to authenticate the client. Either a path to an SSL certificate file, or
+ two-tuple of (certificate file, key file), or a three-tuple of (certificate
+ file, key file, password).
"""
tmp_headers = self._def_headers.copy() # setting default headers
if headers:
tmp_headers.update(headers)
- return super().get_conn(tmp_headers)
+ return super().get_conn(tmp_headers, verify, proxies, cert)
def run_method(
self,
@@ -108,7 +114,7 @@ def run_method(
:param headers: headers
:type headers: dict
:return: http response
- :rtype: requests.Response
+ :rtype: httpx.Response
"""
if method not in ('GET', 'POST', 'PUT', 'DELETE', 'HEAD'):
raise ValueError(f"Invalid http method '{method}'")
@@ -142,7 +148,7 @@ def post_batch(self, *args: Any, **kwargs: Any) -> Any:
try:
response.raise_for_status()
- except requests.exceptions.HTTPError as err:
+ except httpx.HTTPStatusError as err:
raise AirflowException(
"Could not submit batch. Status code: {}. Message: '{}'".format(
err.response.status_code, err.response.text
@@ -172,7 +178,7 @@ def get_batch(self, session_id: Union[int, str]) -> Any:
try:
response.raise_for_status()
- except requests.exceptions.HTTPError as err:
+ except httpx.HTTPStatusError as err:
self.log.warning("Got status code %d for session %d", err.response.status_code, session_id)
raise AirflowException(
f"Unable to fetch batch with id: {session_id}. Message: {err.response.text}"
@@ -196,7 +202,7 @@ def get_batch_state(self, session_id: Union[int, str]) -> BatchState:
try:
response.raise_for_status()
- except requests.exceptions.HTTPError as err:
+ except httpx.HTTPStatusError as err:
self.log.warning("Got status code %d for session %d", err.response.status_code, session_id)
raise AirflowException(
f"Unable to fetch batch with id: {session_id}. Message: {err.response.text}"
@@ -223,7 +229,7 @@ def delete_batch(self, session_id: Union[int, str]) -> Any:
try:
response.raise_for_status()
- except requests.exceptions.HTTPError as err:
+ except httpx.HTTPStatusError as err:
self.log.warning("Got status code %d for session %d", err.response.status_code, session_id)
raise AirflowException(
"Could not kill the batch with session id: {}. Message: {}".format(
diff --git a/airflow/providers/apache/livy/provider.yaml b/airflow/providers/apache/livy/provider.yaml
index 309c2f538aa1a..f6999fa4f7e5c 100644
--- a/airflow/providers/apache/livy/provider.yaml
+++ b/airflow/providers/apache/livy/provider.yaml
@@ -22,6 +22,7 @@ description: |
`Apache Livy `__
versions:
+ - 2.0.0
- 1.1.0
- 1.0.1
- 1.0.0
diff --git a/airflow/providers/http/CHANGELOG.rst b/airflow/providers/http/CHANGELOG.rst
index b4a0313336731..dd3f98d3786c8 100644
--- a/airflow/providers/http/CHANGELOG.rst
+++ b/airflow/providers/http/CHANGELOG.rst
@@ -19,6 +19,36 @@
Changelog
---------
+
+2.0.0
+.....
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+Due to licencing issues, the HTTP provider switched from ``requests`` to ``httpx`` library.
+In case you use an authentication method different than default Basic Authenticaton,
+you will need to change it to use the httpx-compatible one.
+
+HttpHook's run() method passes any kwargs passed to underlying httpx.Requests object rather than
+to requests.Requests. They largely compatible (for example json kwarg is supported in both) but not fully.
+The ``content`` and ``auth`` parameters are not supported in ``httpx``.
+
+The "verify", "proxie" and "cert" extra parameters
+
+The get_conn() method of HttpHook returns ``httpx.Client`` rather than ``requests.Session``.
+
+The get_conn() method of HttpHook has extra parameters: verify ,proxies and cert which are executed by the
+run() method, so if your hook derives from the HttpHook it should be updated.
+
+The ``run_and_check`` method is gone. PreparedRequests are not supported in ``httpx`` and this method
+used it. Instead, run method simply creates and executes the full request instance.
+
+While ``httpx`` does not support ``REQUESTS_CA_BUNDLE`` variable overriding ca crt, we backported
+the requests behaviour to HTTPHook and it respects the variable if set. More details on that variable
+can be found `here None:
super().__init__()
self.http_conn_id = http_conn_id
self.method = method.upper()
self.base_url: str = ""
self._retry_obj: Callable[..., Any]
- self.auth_type: Any = auth_type
+ self.auth_type: Type[Auth] = auth_type
+ self.extra_headers: Dict = {}
+ self.conn = None
# headers may be passed through directly or in the "extra" field in the connection
# definition
- def get_conn(self, headers: Optional[Dict[Any, Any]] = None) -> requests.Session:
+ def get_conn(
+ self, headers: Optional[Dict[Any, Any]] = None, verify: bool = True, proxies=None, cert=None
+ ) -> httpx.Client:
"""
- Returns http session for use with requests
+ Returns httpx client for use with requests
:param headers: additional headers to be passed through as a dictionary
:type headers: dict
+ :param verify: whether to verify SSL during the connection (only use for testing)
+ :param proxies: A dictionary mapping proxy keys to proxy
+ :param cert: client An SSL certificate used by the requested host
+ to authenticate the client. Either a path to an SSL certificate file, or
+ two-tuple of (certificate file, key file), or a three-tuple of (certificate
+ file, key file, password).
"""
- session = requests.Session()
+ client = httpx.Client(verify=verify, proxies=proxies, cert=cert)
if self.http_conn_id:
conn = self.get_connection(self.http_conn_id)
-
if conn.host and "://" in conn.host:
self.base_url = conn.host
else:
@@ -82,16 +93,20 @@ def get_conn(self, headers: Optional[Dict[Any, Any]] = None) -> requests.Session
if conn.port:
self.base_url = self.base_url + ":" + str(conn.port)
if conn.login:
- session.auth = self.auth_type(conn.login, conn.password)
+ # Note! This handles Basic Auth and DigestAuth and any other authentication that
+ # supports login/password in the constructor.
+ client.auth = self.auth_type(conn.login, conn.password) # noqa
if conn.extra:
try:
- session.headers.update(conn.extra_dejson)
+ client.headers.update(conn.extra_dejson)
except TypeError:
self.log.warning('Connection to %s has invalid extra field.', conn.host)
+ # Hooks deriving from HTTP Hook might use it to query connection details
+ self.conn = conn
if headers:
- session.headers.update(headers)
+ client.headers.update(headers)
- return session
+ return client
def run(
self,
@@ -111,16 +126,24 @@ def run(
:param headers: additional headers to be passed through as a dictionary
:type headers: dict
:param extra_options: additional options to be used when executing the request
- i.e. {'check_response': False} to avoid checking raising exceptions on non
- 2XX or 3XX status codes
+ i.e. ``{'check_response': False}`` to avoid checking raising exceptions on non
+ 2XX or 3XX status codes. The extra options can take the following keys:
+ verify, proxies, cert, stream, allow_redirects, timeout, check_response.
+ See ``httpx.Client`` for description of those parameters.
:type extra_options: dict
:param request_kwargs: Additional kwargs to pass when creating a request.
- For example, ``run(json=obj)`` is passed as ``requests.Request(json=obj)``
+ For example, ``run(json=obj)`` is passed as ``httpx.Request(json=obj)``
"""
extra_options = extra_options or {}
- session = self.get_conn(headers)
-
+ verify = extra_options.get("verify", True)
+ if verify is True:
+ # Only use REQUESTS_CA_BUNDLE content if verify is set to True,
+ # otherwise use passed value as it can be string, or SSLcontext
+ verify = os.environ.get('REQUESTS_CA_BUNDLE', verify)
+ proxies = extra_options.get("proxies", None)
+ cert = extra_options.get("cert", None)
+ client = self.get_conn(headers, verify=verify, proxies=proxies, cert=cert)
if self.base_url and not self.base_url.endswith('/') and endpoint and not endpoint.startswith('/'):
url = self.base_url + '/' + endpoint
else:
@@ -128,19 +151,33 @@ def run(
if self.method == 'GET':
# GET uses params
- req = requests.Request(self.method, url, params=data, headers=headers, **request_kwargs)
+ req = httpx.Request(self.method, url, params=data, headers=client.headers, **request_kwargs)
elif self.method == 'HEAD':
# HEAD doesn't use params
- req = requests.Request(self.method, url, headers=headers, **request_kwargs)
+ req = httpx.Request(self.method, url, headers=client.headers, **request_kwargs)
else:
# Others use data
- req = requests.Request(self.method, url, data=data, headers=headers, **request_kwargs)
+ req = httpx.Request(self.method, url, headers=client.headers, data=data, **request_kwargs)
- prepped_request = session.prepare_request(req)
+ # Send the request
+ send_kwargs = {
+ "stream": extra_options.get("stream", False),
+ "allow_redirects": extra_options.get("allow_redirects", True),
+ "timeout": extra_options.get("timeout"),
+ }
self.log.info("Sending '%s' to url: %s", self.method, url)
- return self.run_and_check(session, prepped_request, extra_options)
- def check_response(self, response: requests.Response) -> None:
+ try:
+ response = client.send(req, **send_kwargs)
+ if extra_options.get('check_response', True):
+ self.check_response(response)
+ return response
+
+ except httpx.NetworkError as ex:
+ self.log.warning('%s Tenacity will retry to execute the operation', ex)
+ raise ex
+
+ def check_response(self, response: httpx.Response) -> None:
"""
Checks the status code and raise an AirflowException exception on non 2XX or 3XX
status codes
@@ -150,57 +187,15 @@ def check_response(self, response: requests.Response) -> None:
"""
try:
response.raise_for_status()
- except requests.exceptions.HTTPError:
- self.log.error("HTTP error: %s", response.reason)
+ except httpx.HTTPError:
+ phrase = 'Unknown'
+ try:
+ phrase = http.HTTPStatus(response.status_code).phrase
+ except ValueError:
+ pass
+ self.log.error("HTTP error: %s", phrase)
self.log.error(response.text)
- raise AirflowException(str(response.status_code) + ":" + response.reason)
-
- def run_and_check(
- self,
- session: requests.Session,
- prepped_request: requests.PreparedRequest,
- extra_options: Dict[Any, Any],
- ) -> Any:
- """
- Grabs extra options like timeout and actually runs the request,
- checking for the result
-
- :param session: the session to be used to execute the request
- :type session: requests.Session
- :param prepped_request: the prepared request generated in run()
- :type prepped_request: session.prepare_request
- :param extra_options: additional options to be used when executing the request
- i.e. {'check_response': False} to avoid checking raising exceptions on non 2XX
- or 3XX status codes
- :type extra_options: dict
- """
- extra_options = extra_options or {}
-
- settings = session.merge_environment_settings(
- prepped_request.url,
- proxies=extra_options.get("proxies", {}),
- stream=extra_options.get("stream", False),
- verify=extra_options.get("verify"),
- cert=extra_options.get("cert"),
- )
-
- # Send the request.
- send_kwargs = {
- "timeout": extra_options.get("timeout"),
- "allow_redirects": extra_options.get("allow_redirects", True),
- }
- send_kwargs.update(settings)
-
- try:
- response = session.send(prepped_request, **send_kwargs)
-
- if extra_options.get('check_response', True):
- self.check_response(response)
- return response
-
- except requests.exceptions.ConnectionError as ex:
- self.log.warning('%s Tenacity will retry to execute the operation', ex)
- raise ex
+ raise AirflowException(str(response.status_code) + ":" + response.text)
def run_with_advanced_retry(self, _retry_args: Dict[Any, Any], *args: Any, **kwargs: Any) -> Any:
"""
diff --git a/airflow/providers/http/operators/http.py b/airflow/providers/http/operators/http.py
index b6295185d8ba8..fa04cb56a3297 100644
--- a/airflow/providers/http/operators/http.py
+++ b/airflow/providers/http/operators/http.py
@@ -1,4 +1,3 @@
-#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -15,9 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
from typing import Any, Callable, Dict, Optional, Type
-from requests.auth import AuthBase, HTTPBasicAuth
+from httpx import Auth, BasicAuth
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
@@ -86,7 +86,7 @@ def __init__(
extra_options: Optional[Dict[str, Any]] = None,
http_conn_id: str = 'http_default',
log_response: bool = False,
- auth_type: Type[AuthBase] = HTTPBasicAuth,
+ auth_type: Type[Auth] = BasicAuth,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
diff --git a/airflow/providers/http/provider.yaml b/airflow/providers/http/provider.yaml
index 79a095b0a4417..cf1145b5cb6be 100644
--- a/airflow/providers/http/provider.yaml
+++ b/airflow/providers/http/provider.yaml
@@ -22,6 +22,7 @@ description: |
`Hypertext Transfer Protocol (HTTP) `__
versions:
+ - 2.0.0
- 1.1.1
- 1.1.0
- 1.0.0
diff --git a/airflow/providers/opsgenie/CHANGELOG.rst b/airflow/providers/opsgenie/CHANGELOG.rst
index dd5b1de2f0f31..bae90a5d42015 100644
--- a/airflow/providers/opsgenie/CHANGELOG.rst
+++ b/airflow/providers/opsgenie/CHANGELOG.rst
@@ -19,6 +19,16 @@
Changelog
---------
+2.0.0
+.....
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+The OpsGenie Hook derives now from the new version of Http Hook in Http Provider.
+
+TODO: Add extra dependency to 2.0.0 version of http Provider.
+
1.0.2
.....
diff --git a/airflow/providers/opsgenie/hooks/opsgenie_alert.py b/airflow/providers/opsgenie/hooks/opsgenie_alert.py
index 60f1734fdccbc..4e1ff796198f3 100644
--- a/airflow/providers/opsgenie/hooks/opsgenie_alert.py
+++ b/airflow/providers/opsgenie/hooks/opsgenie_alert.py
@@ -18,9 +18,9 @@
#
import json
-from typing import Any, Optional
+from typing import Any, Dict, Optional
-import requests
+import httpx
from airflow.exceptions import AirflowException
from airflow.providers.http.hooks.http import HttpHook
@@ -54,20 +54,25 @@ def _get_api_key(self) -> str:
)
return api_key
- def get_conn(self, headers: Optional[dict] = None) -> requests.Session:
+ def get_conn(
+ self, headers: Optional[Dict[Any, Any]] = None, verify: bool = True, proxies=None, cert=None
+ ) -> httpx.Client:
"""
Overwrite HttpHook get_conn because this hook just needs base_url
and headers, and does not need generic params
:param headers: additional headers to be passed through as a dictionary
:type headers: dict
+ :param verify: whether to verify SSL during the connection (only use for testing)
+ :param proxies: A dictionary mapping proxy keys to proxy
+ :param cert: client An SSL certificate used by the requested host
+ to authenticate the client. Either a path to an SSL certificate file, or
+ two-tuple of (certificate file, key file), or a three-tuple of (certificate
+ file, key file, password).
"""
- conn = self.get_connection(self.http_conn_id)
- self.base_url = conn.host if conn.host else 'https://api.opsgenie.com'
- session = requests.Session()
- if headers:
- session.headers.update(headers)
- return session
+ client = super().get_conn(headers=headers, verify=verify, proxies=proxies, cert=cert)
+ self.base_url = self.conn.host if self.conn.host else 'https://api.opsgenie.com'
+ return client
def execute(self, payload: Optional[dict] = None) -> Any:
"""
diff --git a/airflow/providers/opsgenie/provider.yaml b/airflow/providers/opsgenie/provider.yaml
index c2e8d22546d10..f2a18ac6643b2 100644
--- a/airflow/providers/opsgenie/provider.yaml
+++ b/airflow/providers/opsgenie/provider.yaml
@@ -22,6 +22,7 @@ description: |
`Opsgenie `__
versions:
+ - 2.0.0
- 1.0.2
- 1.0.1
- 1.0.0
diff --git a/setup.py b/setup.py
index b057f51ec529c..61fe38635029a 100644
--- a/setup.py
+++ b/setup.py
@@ -343,20 +343,7 @@ def get_sphinx_theme_version() -> str:
'pyhive[hive]>=0.6.0',
'thrift>=0.9.2',
]
-http = [
- 'requests>=2.20.0',
-]
-http_provider = [
- # NOTE ! The HTTP provider is NOT preinstalled by default when Airflow is installed - because it
- # depends on `requests` library and until `chardet` is mandatory dependency of `requests`
- # See https://github.com/psf/requests/pull/5797
- # This means that providers that depend on Http and cannot work without it, have to have
- # explicit dependency on `apache-airflow-providers-http` which needs to be pulled in for them.
- # Other cross-provider-dependencies are optional (usually cross-provider dependencies only enable
- # some features of providers and majority of those providers works). They result with an extra,
- # not with the `install-requires` dependency.
- 'apache-airflow-providers-http',
-]
+http = []
jdbc = [
'jaydebeapi>=1.1.1',
]
@@ -543,7 +530,7 @@ def get_sphinx_theme_version() -> str:
# Dict of all providers which are part of the Apache Airflow repository together with their requirements
PROVIDERS_REQUIREMENTS: Dict[str, List[str]] = {
- 'airbyte': http_provider,
+ 'airbyte': http,
'amazon': amazon,
'apache.beam': apache_beam,
'apache.cassandra': cassandra,
@@ -551,7 +538,7 @@ def get_sphinx_theme_version() -> str:
'apache.hdfs': hdfs,
'apache.hive': hive,
'apache.kylin': kylin,
- 'apache.livy': http_provider,
+ 'apache.livy': http,
'apache.pig': [],
'apache.pinot': pinot,
'apache.spark': spark,
@@ -584,7 +571,7 @@ def get_sphinx_theme_version() -> str:
'neo4j': neo4j,
'odbc': odbc,
'openfaas': [],
- 'opsgenie': http_provider,
+ 'opsgenie': http,
'oracle': oracle,
'pagerduty': pagerduty,
'papermill': papermill,
diff --git a/tests/providers/airbyte/hooks/test_airbyte.py b/tests/providers/airbyte/hooks/test_airbyte.py
index 09f10beffc255..45ba2b107c2d5 100644
--- a/tests/providers/airbyte/hooks/test_airbyte.py
+++ b/tests/providers/airbyte/hooks/test_airbyte.py
@@ -20,7 +20,6 @@
from unittest import mock
import pytest
-import requests_mock
from airflow.exceptions import AirflowException
from airflow.models import Connection
@@ -34,12 +33,7 @@ class TestAirbyteHook(unittest.TestCase):
"""
airbyte_conn_id = 'airbyte_conn_id_test'
- connection_id = 'conn_test_sync'
- job_id = 1
- sync_connection_endpoint = 'http://test-airbyte:8001/api/v1/connections/sync'
- get_job_endpoint = 'http://test-airbyte:8001/api/v1/jobs/get'
- _mock_sync_conn_success_response_body = {'job': {'id': 1}}
- _mock_job_status_success_response_body = {'job': {'status': 'succeeded'}}
+ job_id = '1'
def setUp(self):
db.merge_conn(
@@ -54,22 +48,6 @@ def return_value_get_job(self, status):
response.json.return_value = {'job': {'status': status}}
return response
- @requests_mock.mock()
- def test_submit_sync_connection(self, m):
- m.post(
- self.sync_connection_endpoint, status_code=200, json=self._mock_sync_conn_success_response_body
- )
- resp = self.hook.submit_sync_connection(connection_id=self.connection_id)
- assert resp.status_code == 200
- assert resp.json() == self._mock_sync_conn_success_response_body
-
- @requests_mock.mock()
- def test_get_job_status(self, m):
- m.post(self.get_job_endpoint, status_code=200, json=self._mock_job_status_success_response_body)
- resp = self.hook.get_job(job_id=self.job_id)
- assert resp.status_code == 200
- assert resp.json() == self._mock_job_status_success_response_body
-
@mock.patch('airflow.providers.airbyte.hooks.airbyte.AirbyteHook.get_job')
def test_wait_for_job_succeeded(self, mock_get_job):
mock_get_job.side_effect = [self.return_value_get_job(self.hook.SUCCEEDED)]
@@ -124,3 +102,38 @@ def test_wait_for_job_cancelled(self, mock_get_job):
calls = [mock.call(job_id=self.job_id), mock.call(job_id=self.job_id)]
assert mock_get_job.has_calls(calls)
+
+
+@pytest.fixture
+def setup_hook():
+ yield AirbyteHook(airbyte_conn_id='airbyte_conn_id_test')
+
+
+class TestAirbyteMockHttpx:
+ sync_connection_endpoint = 'http://test-airbyte:8001/api/v1/connections/sync'
+ get_job_endpoint = 'http://test-airbyte:8001/api/v1/jobs/get'
+ connection_id = 'conn_test_sync'
+ _mock_sync_conn_success_response_body = {'job': {'id': 1}}
+ _mock_job_status_success_response_body = {'job': {'status': 'succeeded'}}
+
+ def test_submit_sync_connection(self, httpx_mock, setup_hook):
+ httpx_mock.add_response(
+ method='POST',
+ url=self.sync_connection_endpoint,
+ status_code=200,
+ json=self._mock_sync_conn_success_response_body,
+ )
+ resp = setup_hook.submit_sync_connection(connection_id=self.connection_id)
+ assert resp.status_code == 200
+ assert resp.json() == self._mock_sync_conn_success_response_body
+
+ def test_get_job_status(self, httpx_mock, setup_hook):
+ httpx_mock.add_response(
+ method='POST',
+ url=self.get_job_endpoint,
+ status_code=200,
+ json=self._mock_job_status_success_response_body,
+ )
+ resp = setup_hook.get_job(job_id='1')
+ assert resp.status_code == 200
+ assert resp.json() == self._mock_job_status_success_response_body
diff --git a/tests/providers/apache/livy/hooks/test_livy.py b/tests/providers/apache/livy/hooks/test_livy.py
index 86b7acaa6fc39..ed9b6d27f66f8 100644
--- a/tests/providers/apache/livy/hooks/test_livy.py
+++ b/tests/providers/apache/livy/hooks/test_livy.py
@@ -14,14 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+# pylint: disable=unused-argument
import json
import unittest
-from unittest.mock import patch
+import httpx
import pytest
-import requests_mock
-from requests.exceptions import RequestException
from airflow.exceptions import AirflowException
from airflow.models import Connection
@@ -33,39 +32,6 @@
class TestLivyHook(unittest.TestCase):
- @classmethod
- def setUpClass(cls):
- db.merge_conn(
- Connection(conn_id='livy_default', conn_type='http', host='host', schema='http', port=8998)
- )
- db.merge_conn(Connection(conn_id='default_port', conn_type='http', host='http://host'))
- db.merge_conn(Connection(conn_id='default_protocol', conn_type='http', host='host'))
- db.merge_conn(Connection(conn_id='port_set', host='host', conn_type='http', port=1234))
- db.merge_conn(Connection(conn_id='schema_set', host='host', conn_type='http', schema='zzz'))
- db.merge_conn(
- Connection(conn_id='dont_override_schema', conn_type='http', host='http://host', schema='zzz')
- )
- db.merge_conn(Connection(conn_id='missing_host', conn_type='http', port=1234))
- db.merge_conn(Connection(conn_id='invalid_uri', uri='http://invalid_uri:4321'))
-
- def test_build_get_hook(self):
-
- connection_url_mapping = {
- # id, expected
- 'default_port': 'http://host',
- 'default_protocol': 'http://host',
- 'port_set': 'http://host:1234',
- 'schema_set': 'zzz://host',
- 'dont_override_schema': 'http://host',
- }
-
- for conn_id, expected in connection_url_mapping.items():
- with self.subTest(conn_id):
- hook = LivyHook(livy_conn_id=conn_id)
-
- hook.get_conn()
- assert hook.base_url == expected
-
@unittest.skip("inherited HttpHook does not handle missing hostname")
def test_missing_host(self):
with pytest.raises(AirflowException):
@@ -241,35 +207,81 @@ def test_validate_extra_conf(self):
with pytest.raises(ValueError):
LivyHook._validate_extra_conf({'has_val': 'val', 'no_val': ''})
- @patch('airflow.providers.apache.livy.hooks.livy.LivyHook.run_method')
- def test_post_batch_arguments(self, mock_request):
+ def test_check_session_id(self):
+ with self.subTest('valid 00'):
+ try:
+ LivyHook._validate_session_id(100)
+ except TypeError:
+ self.fail("")
- mock_request.return_value.status_code = 201
- mock_request.return_value.json.return_value = {
- 'id': BATCH_ID,
- 'state': BatchState.STARTING.value,
- 'log': [],
- }
+ with self.subTest('valid 01'):
+ try:
+ LivyHook._validate_session_id(0)
+ except TypeError:
+ self.fail("")
- hook = LivyHook()
- resp = hook.post_batch(file='sparkapp')
+ with self.subTest('None'):
+ with pytest.raises(TypeError):
+ LivyHook._validate_session_id(None) # noqa
- mock_request.assert_called_once_with(
- method='POST', endpoint='/batches', data=json.dumps({'file': 'sparkapp'})
- )
+ with self.subTest('random string'):
+ with pytest.raises(TypeError):
+ LivyHook._validate_session_id('asd')
+
+
+@pytest.fixture(scope="class")
+def setup_connections():
+ db.merge_conn(Connection(conn_id='livy_default', conn_type='http', host='host', schema='http', port=8998))
+ db.merge_conn(Connection(conn_id='default_port', conn_type='http', host='http://host'))
+ db.merge_conn(Connection(conn_id='default_protocol', conn_type='http', host='host'))
+ db.merge_conn(Connection(conn_id='port_set', host='host', conn_type='http', port=1234))
+ db.merge_conn(Connection(conn_id='schema_set', host='host', conn_type='http', schema='zzz'))
+ db.merge_conn(
+ Connection(conn_id='dont_override_schema', conn_type='http', host='http://host', schema='zzz')
+ )
+ db.merge_conn(Connection(conn_id='missing_host', conn_type='http', port=1234))
+ db.merge_conn(Connection(conn_id='invalid_uri', uri='http://invalid_uri:4321'))
- request_args = mock_request.call_args[1]
- assert 'data' in request_args
- assert isinstance(request_args['data'], str)
+class TestLivyMockHttpx:
+ def test_build_get_hook(self, setup_connections):
+
+ connection_url_mapping = {
+ # id, expected
+ 'default_port': 'http://host',
+ 'default_protocol': 'http://host',
+ 'port_set': 'http://host:1234',
+ 'schema_set': 'zzz://host',
+ 'dont_override_schema': 'http://host',
+ }
+
+ for conn_id, expected in connection_url_mapping.items():
+ hook = LivyHook(livy_conn_id=conn_id)
+
+ hook.get_conn()
+ assert hook.base_url == expected
+
+ def test_post_batch_arguments(self, httpx_mock, setup_connections):
+ httpx_mock.add_response(
+ method='POST',
+ url="http://livy:8998/batches",
+ status_code=201,
+ match_content=json.dumps({'file': 'sparkapp'}).encode("utf-8"),
+ json={
+ 'id': BATCH_ID,
+ 'state': BatchState.STARTING.value,
+ 'log': [],
+ },
+ )
+ hook = LivyHook()
+ resp = hook.post_batch(file='sparkapp')
assert isinstance(resp, int)
assert resp == BATCH_ID
- @requests_mock.mock()
- def test_post_batch_success(self, mock):
- mock.register_uri(
- 'POST',
- '//livy:8998/batches',
+ def test_post_batch_success(self, httpx_mock, setup_connections):
+ httpx_mock.add_response(
+ method='POST',
+ url="http://livy:8998/batches",
json={'id': BATCH_ID, 'state': BatchState.STARTING.value, 'log': []},
status_code=201,
)
@@ -279,17 +291,21 @@ def test_post_batch_success(self, mock):
assert isinstance(resp, int)
assert resp == BATCH_ID
- @requests_mock.mock()
- def test_post_batch_fail(self, mock):
- mock.register_uri('POST', '//livy:8998/batches', json={}, status_code=400, reason='ERROR')
-
+ def test_post_batch_fail(self, httpx_mock, setup_connections):
+ httpx_mock.add_response(
+ method='POST',
+ url="http://livy:8998/batches",
+ json={},
+ status_code=400,
+ )
hook = LivyHook()
with pytest.raises(AirflowException):
hook.post_batch(file='sparkapp')
- @requests_mock.mock()
- def test_get_batch_success(self, mock):
- mock.register_uri('GET', f'//livy:8998/batches/{BATCH_ID}', json={'id': BATCH_ID}, status_code=200)
+ def test_get_batch_success(self, httpx_mock, setup_connections):
+ httpx_mock.add_response(
+ method='GET', url=f'http://livy:8998/batches/{BATCH_ID}', json={'id': BATCH_ID}, status_code=200
+ )
hook = LivyHook()
resp = hook.get_batch(BATCH_ID)
@@ -297,33 +313,28 @@ def test_get_batch_success(self, mock):
assert isinstance(resp, dict)
assert 'id' in resp
- @requests_mock.mock()
- def test_get_batch_fail(self, mock):
- mock.register_uri(
- 'GET',
- f'//livy:8998/batches/{BATCH_ID}',
+ def test_get_batch_fail(self, httpx_mock, setup_connections):
+ httpx_mock.add_response(
+ method='GET',
+ url=f'http://livy:8998/batches/{BATCH_ID}',
json={'msg': 'Unable to find batch'},
status_code=404,
- reason='ERROR',
)
hook = LivyHook()
with pytest.raises(AirflowException):
hook.get_batch(BATCH_ID)
- def test_invalid_uri(self):
+ def test_invalid_uri(self, setup_connections):
hook = LivyHook(livy_conn_id='invalid_uri')
- with pytest.raises(RequestException):
+ with pytest.raises(httpx.ConnectError):
hook.post_batch(file='sparkapp')
- @requests_mock.mock()
- def test_get_batch_state_success(self, mock):
-
+ def test_get_batch_state_success(self, httpx_mock, setup_connections):
running = BatchState.RUNNING
-
- mock.register_uri(
- 'GET',
- f'//livy:8998/batches/{BATCH_ID}/state',
+ httpx_mock.add_response(
+ method='GET',
+ url=f'http://livy:8998/batches/{BATCH_ID}/state',
json={'id': BATCH_ID, 'state': running.value},
status_code=200,
)
@@ -333,116 +344,98 @@ def test_get_batch_state_success(self, mock):
assert isinstance(state, BatchState)
assert state == running
- @requests_mock.mock()
- def test_get_batch_state_fail(self, mock):
- mock.register_uri(
- 'GET', f'//livy:8998/batches/{BATCH_ID}/state', json={}, status_code=400, reason='ERROR'
+ def test_get_batch_state_fail(self, httpx_mock, setup_connections):
+ httpx_mock.add_response(
+ method='GET', url=f'http://livy:8998/batches/{BATCH_ID}/state', json={}, status_code=400
)
hook = LivyHook()
with pytest.raises(AirflowException):
hook.get_batch_state(BATCH_ID)
- @requests_mock.mock()
- def test_get_batch_state_missing(self, mock):
- mock.register_uri('GET', f'//livy:8998/batches/{BATCH_ID}/state', json={}, status_code=200)
+ def test_get_batch_state_missing(self, httpx_mock, setup_connections):
+ httpx_mock.add_response(
+ method='GET', url=f'http://livy:8998/batches/{BATCH_ID}/state', json={}, status_code=200
+ )
hook = LivyHook()
with pytest.raises(AirflowException):
hook.get_batch_state(BATCH_ID)
- def test_parse_post_response(self):
+ def test_parse_post_response(self, setup_connections):
res_id = LivyHook._parse_post_response({'id': BATCH_ID, 'log': []})
assert BATCH_ID == res_id
- @requests_mock.mock()
- def test_delete_batch_success(self, mock):
- mock.register_uri(
- 'DELETE', f'//livy:8998/batches/{BATCH_ID}', json={'msg': 'deleted'}, status_code=200
+ def test_delete_batch_success(self, httpx_mock, setup_connections):
+ httpx_mock.add_response(
+ method='DELETE',
+ url=f'http://livy:8998/batches/{BATCH_ID}',
+ json={'msg': 'deleted'},
+ status_code=200,
)
resp = LivyHook().delete_batch(BATCH_ID)
assert resp == {'msg': 'deleted'}
- @requests_mock.mock()
- def test_delete_batch_fail(self, mock):
- mock.register_uri(
- 'DELETE', f'//livy:8998/batches/{BATCH_ID}', json={}, status_code=400, reason='ERROR'
+ def test_delete_batch_fail(self, httpx_mock, setup_connections):
+ httpx_mock.add_response(
+ method='DELETE', url=f'http://livy:8998/batches/{BATCH_ID}', json={}, status_code=400
)
hook = LivyHook()
with pytest.raises(AirflowException):
hook.delete_batch(BATCH_ID)
- @requests_mock.mock()
- def test_missing_batch_id(self, mock):
- mock.register_uri('POST', '//livy:8998/batches', json={}, status_code=201)
-
+ def test_missing_batch_id(self, httpx_mock, setup_connections):
+ httpx_mock.add_response(method='POST', url='http://livy:8998/batches', json={}, status_code=201)
hook = LivyHook()
with pytest.raises(AirflowException):
hook.post_batch(file='sparkapp')
- @requests_mock.mock()
- def test_get_batch_validation(self, mock):
- mock.register_uri('GET', f'//livy:8998/batches/{BATCH_ID}', json=SAMPLE_GET_RESPONSE, status_code=200)
+ def test_get_batch_validation(self, httpx_mock, setup_connections):
+ httpx_mock.add_response(
+ method='GET',
+ url=f'http://livy:8998/batches/{BATCH_ID}',
+ json=SAMPLE_GET_RESPONSE,
+ status_code=200,
+ )
hook = LivyHook()
- with self.subTest('get_batch'):
- hook.get_batch(BATCH_ID)
+ hook.get_batch(BATCH_ID)
# make sure blocked by validation
for val in [None, 'one', {'a': 'b'}]:
- with self.subTest(f'get_batch {val}'):
- with pytest.raises(TypeError):
- hook.get_batch(val)
-
- @requests_mock.mock()
- def test_get_batch_state_validation(self, mock):
- mock.register_uri(
- 'GET', f'//livy:8998/batches/{BATCH_ID}/state', json=SAMPLE_GET_RESPONSE, status_code=200
+ with pytest.raises(TypeError):
+ hook.get_batch(val)
+
+ def test_get_batch_state_validation(self, httpx_mock, setup_connections):
+ httpx_mock.add_response(
+ method='GET',
+ url=f'http://livy:8998/batches/{BATCH_ID}/state',
+ json=SAMPLE_GET_RESPONSE,
+ status_code=200,
)
hook = LivyHook()
- with self.subTest('get_batch'):
- hook.get_batch_state(BATCH_ID)
+ hook.get_batch_state(BATCH_ID)
for val in [None, 'one', {'a': 'b'}]:
- with self.subTest(f'get_batch {val}'):
- with pytest.raises(TypeError):
- hook.get_batch_state(val)
+ with pytest.raises(TypeError):
+ hook.get_batch_state(val)
- @requests_mock.mock()
- def test_delete_batch_validation(self, mock):
- mock.register_uri('DELETE', f'//livy:8998/batches/{BATCH_ID}', json={'id': BATCH_ID}, status_code=200)
+ def test_delete_batch_validation(self, httpx_mock, setup_connections):
+ httpx_mock.add_response(
+ method='DELETE',
+ url=f'http://livy:8998/batches/{BATCH_ID}',
+ json={'id': BATCH_ID},
+ status_code=200,
+ )
hook = LivyHook()
- with self.subTest('get_batch'):
- hook.delete_batch(BATCH_ID)
+ hook.delete_batch(BATCH_ID)
for val in [None, 'one', {'a': 'b'}]:
- with self.subTest(f'get_batch {val}'):
- with pytest.raises(TypeError):
- hook.delete_batch(val)
-
- def test_check_session_id(self):
- with self.subTest('valid 00'):
- try:
- LivyHook._validate_session_id(100)
- except TypeError:
- self.fail("")
-
- with self.subTest('valid 01'):
- try:
- LivyHook._validate_session_id(0)
- except TypeError:
- self.fail("")
-
- with self.subTest('None'):
with pytest.raises(TypeError):
- LivyHook._validate_session_id(None) # noqa
-
- with self.subTest('random string'):
- with pytest.raises(TypeError):
- LivyHook._validate_session_id('asd')
+ hook.delete_batch(val)
diff --git a/tests/providers/http/hooks/test_http.py b/tests/providers/http/hooks/test_http.py
index 825847b982156..3a0a20714207e 100644
--- a/tests/providers/http/hooks/test_http.py
+++ b/tests/providers/http/hooks/test_http.py
@@ -17,15 +17,11 @@
# under the License.
import json
import os
-import unittest
-from collections import OrderedDict
from unittest import mock
+import httpx
import pytest
-import requests
-import requests_mock
import tenacity
-from parameterized import parameterized
from airflow.exceptions import AirflowException
from airflow.models import Connection
@@ -40,41 +36,42 @@ def get_airflow_connection_with_port(unused_conn_id=None):
return Connection(conn_id='http_default', conn_type='http', host='test.com', port=1234)
-class TestHttpHook(unittest.TestCase):
- """Test get, post and raise_for_status"""
+@pytest.fixture
+def setup_hook():
+ yield HttpHook(method='GET')
+
+
+@pytest.fixture
+def setup_lowercase_hook():
+ yield HttpHook(method='get')
- def setUp(self):
- session = requests.Session()
- adapter = requests_mock.Adapter()
- session.mount('mock', adapter)
- self.get_hook = HttpHook(method='GET')
- self.get_lowercase_hook = HttpHook(method='get')
- self.post_hook = HttpHook(method='POST')
- @requests_mock.mock()
- def test_raise_for_status_with_200(self, m):
+@pytest.fixture
+def setup_post_hook():
+ yield HttpHook(method='POST')
+
+
+class TestHttpHook:
+ """Test get, post and raise_for_status"""
- m.get('http://test:8080/v1/test', status_code=200, text='{"status":{"status": 200}}', reason='OK')
+ def test_raise_for_status_with_200(self, httpx_mock, setup_hook):
+ httpx_mock.add_response(
+ url='http://test:8080/v1/test', status_code=200, data='{"status":{"status": 200}}'
+ )
with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection):
- resp = self.get_hook.run('v1/test')
+ resp = setup_hook.run('v1/test')
assert resp.text == '{"status":{"status": 200}}'
- @requests_mock.mock()
- @mock.patch('requests.Session')
- @mock.patch('requests.Request')
- def test_get_request_with_port(self, mock_requests, request_mock, mock_session):
- from requests.exceptions import MissingSchema
-
+ @mock.patch('httpx.Client')
+ @mock.patch('httpx.Request')
+ def test_get_request_with_port(self, request_mock, mock_client, setup_hook):
with mock.patch(
'airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection_with_port
):
expected_url = 'http://test.com:1234/some/endpoint'
for endpoint in ['some/endpoint', '/some/endpoint']:
- try:
- self.get_hook.run(endpoint)
- except MissingSchema:
- pass
+ setup_hook.run(endpoint)
request_mock.assert_called_once_with(
mock.ANY, expected_url, headers=mock.ANY, params=mock.ANY
@@ -82,117 +79,116 @@ def test_get_request_with_port(self, mock_requests, request_mock, mock_session):
request_mock.reset_mock()
- @requests_mock.mock()
- def test_get_request_do_not_raise_for_status_if_check_response_is_false(self, m):
-
- m.get(
- 'http://test:8080/v1/test',
+ def test_get_request_do_not_raise_for_status_if_check_response_is_false(self, httpx_mock, setup_hook):
+ httpx_mock.add_response(
+ url='http://test:8080/v1/test',
status_code=404,
- text='{"status":{"status": 404}}',
- reason='Bad request',
+ data='{"status":{"status": 404}}',
)
with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection):
- resp = self.get_hook.run('v1/test', extra_options={'check_response': False})
+ resp = setup_hook.run('v1/test', extra_options={'check_response': False})
assert resp.text == '{"status":{"status": 404}}'
- @requests_mock.mock()
- def test_hook_contains_header_from_extra_field(self, mock_requests):
+ def test_hook_contains_header_from_extra_field(self, setup_hook):
with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection):
expected_conn = get_airflow_connection()
- conn = self.get_hook.get_conn()
+ conn = setup_hook.get_conn()
assert dict(conn.headers, **json.loads(expected_conn.extra)) == conn.headers
assert conn.headers.get('bareer') == 'test'
- @requests_mock.mock()
- @mock.patch('requests.Request')
- def test_hook_with_method_in_lowercase(self, mock_requests, request_mock):
- from requests.exceptions import InvalidURL, MissingSchema
-
+ def test_hook_with_method_in_lowercase(self, httpx_mock, setup_lowercase_hook):
+ data = "test_params=aaaa"
+ httpx_mock.add_response(
+ url='http://test.com:1234/v1/test?test_params=aaaa',
+ status_code=200,
+ data='{"status":{"status": 200}}',
+ )
with mock.patch(
'airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection_with_port
):
- data = "test params"
try:
- self.get_lowercase_hook.run('v1/test', data=data)
- except (MissingSchema, InvalidURL):
+ setup_lowercase_hook.run('v1/test', data=data)
+ except ConnectionError:
pass
- request_mock.assert_called_once_with(mock.ANY, mock.ANY, headers=mock.ANY, params=data)
- @requests_mock.mock()
- def test_hook_uses_provided_header(self, mock_requests):
- conn = self.get_hook.get_conn(headers={"bareer": "newT0k3n"})
+ def test_hook_uses_provided_header(self, setup_hook):
+ conn = setup_hook.get_conn(headers={"bareer": "newT0k3n"})
assert conn.headers.get('bareer') == "newT0k3n"
- @requests_mock.mock()
- def test_hook_has_no_header_from_extra(self, mock_requests):
- conn = self.get_hook.get_conn()
+ def test_hook_has_no_header_from_extra(self, setup_hook):
+ conn = setup_hook.get_conn()
assert conn.headers.get('bareer') is None
- @requests_mock.mock()
- def test_hooks_header_from_extra_is_overridden(self, mock_requests):
+ def test_hooks_header_from_extra_is_overridden(self, setup_hook):
with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection):
- conn = self.get_hook.get_conn(headers={"bareer": "newT0k3n"})
- assert conn.headers.get('bareer') == 'newT0k3n'
+ conn = setup_hook.get_conn()
+ assert conn.headers.get('bareer') == 'test'
- @requests_mock.mock()
- def test_post_request(self, mock_requests):
- mock_requests.post(
- 'http://test:8080/v1/test', status_code=200, text='{"status":{"status": 200}}', reason='OK'
+ def test_hooks_header_from_extra_is_overridden_and_used(self, httpx_mock, setup_hook):
+ httpx_mock.add_response(
+ url='http://test:8080/v1/test',
+ status_code=200,
+ data='{"status":{"status": 200}}',
+ match_headers={"bareer": "test"},
+ )
+ with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection):
+ setup_hook.run('v1/test', extra_options={'check_response': False})
+
+ def test_post_request(self, httpx_mock, setup_post_hook):
+ httpx_mock.add_response(
+ method='POST', url='http://test:8080/v1/test', status_code=200, data='{"status":{"status": 200}}'
)
with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection):
- resp = self.post_hook.run('v1/test')
+ resp = setup_post_hook.run('v1/test')
assert resp.status_code == 200
- @requests_mock.mock()
- def test_post_request_with_error_code(self, mock_requests):
- mock_requests.post(
- 'http://test:8080/v1/test',
+ def test_post_request_with_error_code(self, httpx_mock, setup_post_hook):
+ httpx_mock.add_response(
+ method='POST',
+ url='http://test:8080/v1/test',
status_code=418,
- text='{"status":{"status": 418}}',
- reason='I\'m a teapot',
)
with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection):
with pytest.raises(AirflowException):
- self.post_hook.run('v1/test')
-
- @requests_mock.mock()
- def test_post_request_do_not_raise_for_status_if_check_response_is_false(self, mock_requests):
- mock_requests.post(
- 'http://test:8080/v1/test',
+ setup_post_hook.run('v1/test')
+
+ def test_post_request_do_not_raise_for_status_if_check_response_is_false(
+ self, httpx_mock, setup_post_hook
+ ):
+ httpx_mock.add_response(
+ method='POST',
+ url='http://test:8080/v1/test',
status_code=418,
- text='{"status":{"status": 418}}',
- reason='I\'m a teapot',
)
with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection):
- resp = self.post_hook.run('v1/test', extra_options={'check_response': False})
+ resp = setup_post_hook.run('v1/test', extra_options={'check_response': False})
assert resp.status_code == 418
- @mock.patch('airflow.providers.http.hooks.http.requests.Session')
- def test_retry_on_conn_error(self, mocked_session):
+ @mock.patch('airflow.providers.http.hooks.http.httpx.Client')
+ def test_retry_on_conn_error(self, mocked_client, setup_hook):
retry_args = dict(
wait=tenacity.wait_none(),
stop=tenacity.stop_after_attempt(7),
- retry=tenacity.retry_if_exception_type(requests.exceptions.ConnectionError),
+ retry=tenacity.retry_if_exception_type(httpx.NetworkError),
)
def send_and_raise(unused_request, **kwargs):
- raise requests.exceptions.ConnectionError
+ raise httpx.NetworkError(message="ConnectionError")
- mocked_session().send.side_effect = send_and_raise
+ mocked_client().send.side_effect = send_and_raise
# The job failed for some reason
with pytest.raises(tenacity.RetryError):
- self.get_hook.run_with_advanced_retry(endpoint='v1/test', _retry_args=retry_args)
- assert self.get_hook._retry_obj.stop.max_attempt_number + 1 == mocked_session.call_count
+ setup_hook.run_with_advanced_retry(endpoint='v1/test', _retry_args=retry_args)
+ assert setup_hook._retry_obj.stop.max_attempt_number + 1 == mocked_client.call_count
- @requests_mock.mock()
- def test_run_with_advanced_retry(self, m):
+ def test_run_with_advanced_retry(self, httpx_mock, setup_hook):
- m.get('http://test:8080/v1/test', status_code=200, reason='OK')
+ httpx_mock.add_response(url='http://test:8080/v1/test', status_code=200)
retry_args = dict(
wait=tenacity.wait_none(),
@@ -201,22 +197,18 @@ def test_run_with_advanced_retry(self, m):
reraise=True,
)
with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection):
- response = self.get_hook.run_with_advanced_retry(endpoint='v1/test', _retry_args=retry_args)
- assert isinstance(response, requests.Response)
-
- def test_header_from_extra_and_run_method_are_merged(self):
- def run_and_return(unused_session, prepped_request, unused_extra_options, **kwargs):
- return prepped_request
-
+ response = setup_hook.run_with_advanced_retry(endpoint='v1/test', _retry_args=retry_args)
+ assert isinstance(response, httpx.Response)
+
+ def test_header_from_extra_and_run_method_are_merged(self, httpx_mock, setup_hook):
+ httpx_mock.add_response(
+ url='http://test:8080/v1/test',
+ status_code=200,
+ match_headers={"bareer": "test", "some_other_header": "test"},
+ )
# The job failed for some reason
- with mock.patch(
- 'airflow.providers.http.hooks.http.HttpHook.run_and_check', side_effect=run_and_return
- ):
- with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection):
- prepared_request = self.get_hook.run('v1/test', headers={'some_other_header': 'test'})
- actual = dict(prepared_request.headers)
- assert actual.get('bareer') == 'test'
- assert actual.get('some_other_header') == 'test'
+ with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection):
+ setup_hook.run('v1/test', headers={'some_other_header': 'test'})
@mock.patch('airflow.providers.http.hooks.http.HttpHook.get_connection')
def test_http_connection(self, mock_get_connection):
@@ -250,8 +242,8 @@ def test_host_encoded_https_connection(self, mock_get_connection):
hook.get_conn({})
assert hook.base_url == 'https://localhost'
- def test_method_converted_to_uppercase_when_created_in_lowercase(self):
- assert self.get_lowercase_hook.method == 'GET'
+ def test_method_converted_to_uppercase_when_created_in_lowercase(self, setup_lowercase_hook):
+ assert setup_lowercase_hook.method == 'GET'
@mock.patch('airflow.providers.http.hooks.http.HttpHook.get_connection')
def test_connection_without_host(self, mock_get_connection):
@@ -262,96 +254,90 @@ def test_connection_without_host(self, mock_get_connection):
hook.get_conn({})
assert hook.base_url == 'http://'
- @parameterized.expand(
- [
- 'GET',
- 'POST',
- ]
- )
- @requests_mock.mock()
- def test_json_request(self, method, mock_requests):
+ def test_json_request_get(self, httpx_mock):
obj1 = {'a': 1, 'b': 'abc', 'c': [1, 2, {"d": 10}]}
- def match_obj1(request):
- return request.json() == obj1
+ httpx_mock.add_response(
+ method='GET', url='http://test:8080/v1/test', match_content=json.dumps(obj1).encode('utf-8')
+ )
- mock_requests.request(method=method, url='//test:8080/v1/test', additional_matcher=match_obj1)
+ with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection):
+ # will raise NoMockAddress exception if obj1 != request.json()
+ HttpHook(method='GET').run('v1/test', json=obj1)
+
+ def test_json_request_post(self, httpx_mock):
+ obj1 = {'a': 1, 'b': 'abc', 'c': [1, 2, {"d": 10}]}
+
+ httpx_mock.add_response(
+ method='POST', url='http://test:8080/v1/test', match_content=json.dumps(obj1).encode('utf-8')
+ )
with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection):
# will raise NoMockAddress exception if obj1 != request.json()
- HttpHook(method=method).run('v1/test', json=obj1)
+ HttpHook(method='POST').run('v1/test', json=obj1)
- @mock.patch('airflow.providers.http.hooks.http.requests.Session.send')
- def test_verify_set_to_true_by_default(self, mock_session_send):
+ @mock.patch('airflow.providers.http.hooks.http.httpx.Client')
+ def test_verify_set_to_true_by_default(self, mock_client, setup_hook):
with mock.patch(
'airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection_with_port
):
- self.get_hook.run('/some/endpoint')
- mock_session_send.assert_called_once_with(
- mock.ANY,
- allow_redirects=True,
- cert=None,
- proxies=OrderedDict(),
- stream=False,
- timeout=None,
- verify=True,
+ setup_hook.run('/some/endpoint')
+ mock_client.assert_called_once_with(verify=True, cert=None, proxies=None)
+
+ @mock.patch('airflow.providers.http.hooks.http.httpx.Client')
+ def test_verify_can_be_overridden(self, mock_client, setup_hook):
+ with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection):
+ setup_hook.run('/some/endpoint', extra_options={'verify': False})
+ mock_client.assert_called_once_with(verify=False, cert=None, proxies=None)
+
+ @mock.patch('airflow.providers.http.hooks.http.httpx.Client')
+ def test_cert_can_be_overridden(self, mock_client, setup_hook):
+ with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection):
+ setup_hook.run('/some/endpoint', extra_options={'cert': '/tmp/private.crt'})
+ mock_client.assert_called_once_with(verify=True, cert='/tmp/private.crt', proxies=None)
+
+ @mock.patch('airflow.providers.http.hooks.http.httpx.Client')
+ def test_proxies_can_be_overridden(self, mock_client, setup_hook):
+ with mock.patch('airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection):
+ setup_hook.run('/some/endpoint', extra_options={'proxies': {"http://localhost": 'http://proxy'}})
+ mock_client.assert_called_once_with(
+ verify=True, cert=None, proxies={"http://localhost": 'http://proxy'}
)
- @mock.patch('airflow.providers.http.hooks.http.requests.Session.send')
- @mock.patch.dict(os.environ, {"REQUESTS_CA_BUNDLE": "/tmp/test.crt"})
- def test_requests_ca_bundle_env_var(self, mock_session_send):
+ @mock.patch('airflow.providers.http.hooks.http.httpx.Client')
+ def test_verifu_parameter_set(self, mock_client, setup_hook):
with mock.patch(
'airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection_with_port
):
- self.get_hook.run('/some/endpoint')
-
- mock_session_send.assert_called_once_with(
- mock.ANY,
- allow_redirects=True,
- cert=None,
- proxies=OrderedDict(),
- stream=False,
- timeout=None,
- verify='/tmp/test.crt',
- )
+ setup_hook.run('/some/endpoint', extra_options={'verify': '/tmp/overridden.crt'})
+ mock_client.assert_called_once_with(cert=None, proxies=None, verify='/tmp/overridden.crt')
- @mock.patch('airflow.providers.http.hooks.http.requests.Session.send')
+ @mock.patch('airflow.providers.http.hooks.http.httpx.Client')
@mock.patch.dict(os.environ, {"REQUESTS_CA_BUNDLE": "/tmp/test.crt"})
- def test_verify_respects_requests_ca_bundle_env_var(self, mock_session_send):
+ def test_requests_ca_bundle_env_var(self, mock_client, setup_hook):
with mock.patch(
'airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection_with_port
):
- self.get_hook.run('/some/endpoint', extra_options={'verify': True})
-
- mock_session_send.assert_called_once_with(
- mock.ANY,
- allow_redirects=True,
- cert=None,
- proxies=OrderedDict(),
- stream=False,
- timeout=None,
- verify='/tmp/test.crt',
- )
+ setup_hook.run('/some/endpoint')
+ mock_client.assert_called_once_with(cert=None, proxies=None, verify='/tmp/test.crt')
- @mock.patch('airflow.providers.http.hooks.http.requests.Session.send')
+ @mock.patch('airflow.providers.http.hooks.http.httpx.Client')
@mock.patch.dict(os.environ, {"REQUESTS_CA_BUNDLE": "/tmp/test.crt"})
- def test_verify_false_parameter_overwrites_set_requests_ca_bundle_env_var(self, mock_session_send):
+ def test_verify_respects_requests_ca_bundle_env_var(self, mock_client, setup_hook):
with mock.patch(
'airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection_with_port
):
- self.get_hook.run('/some/endpoint', extra_options={'verify': False})
-
- mock_session_send.assert_called_once_with(
- mock.ANY,
- allow_redirects=True,
- cert=None,
- proxies=OrderedDict(),
- stream=False,
- timeout=None,
- verify=False,
- )
+ setup_hook.run('/some/endpoint', extra_options={'verify': True})
+ mock_client.assert_called_once_with(cert=None, proxies=None, verify='/tmp/test.crt')
-send_email_test = mock.Mock()
+ @mock.patch('airflow.providers.http.hooks.http.httpx.Client')
+ @mock.patch.dict(os.environ, {"REQUESTS_CA_BUNDLE": "/tmp/test.crt"})
+ def test_verify_false_parameter_overwrites_set_requests_ca_bundle_env_var(self, mock_client, setup_hook):
+ with mock.patch(
+ 'airflow.hooks.base.BaseHook.get_connection', side_effect=get_airflow_connection_with_port
+ ):
+ setup_hook.run('/some/endpoint', extra_options={'verify': False})
+ mock_client.assert_called_once_with(cert=None, proxies=None, verify=False)
diff --git a/tests/providers/http/operators/test_http.py b/tests/providers/http/operators/test_http.py
index d4c622c29e6a5..dc62c0371c740 100644
--- a/tests/providers/http/operators/test_http.py
+++ b/tests/providers/http/operators/test_http.py
@@ -16,26 +16,23 @@
# specific language governing permissions and limitations
# under the License.
-import unittest
from unittest import mock
import pytest
-import requests_mock
from airflow.exceptions import AirflowException
from airflow.providers.http.operators.http import SimpleHttpOperator
@mock.patch.dict('os.environ', AIRFLOW_CONN_HTTP_EXAMPLE='http://www.example.com')
-class TestSimpleHttpOp(unittest.TestCase):
- @requests_mock.mock()
- def test_response_in_logs(self, m):
+class TestSimpleHttpOp:
+ def test_response_in_logs(self, httpx_mock):
"""
Test that when using SimpleHttpOperator with 'GET',
the log contains 'Example Domain' in it
"""
- m.get('http://www.example.com', text='Example.com fake response')
+ httpx_mock.add_response(url='http://www.example.com', data='Example.com fake response')
operator = SimpleHttpOperator(
task_id='test_HTTP_op',
method='GET',
@@ -49,8 +46,7 @@ def test_response_in_logs(self, m):
calls = [mock.call('Example.com fake response'), mock.call('Example.com fake response')]
mock_info.has_calls(calls)
- @requests_mock.mock()
- def test_response_in_logs_after_failed_check(self, m):
+ def test_response_in_logs_after_failed_check(self, httpx_mock):
"""
Test that when using SimpleHttpOperator with log_response=True,
the response is logged even if request_check fails
@@ -59,7 +55,7 @@ def test_response_in_logs_after_failed_check(self, m):
def response_check(response):
return response.text != 'invalid response'
- m.get('http://www.example.com', text='invalid response')
+ httpx_mock.add_response(url='http://www.example.com', data='invalid response')
operator = SimpleHttpOperator(
task_id='test_HTTP_op',
method='GET',
@@ -75,9 +71,8 @@ def response_check(response):
calls = [mock.call('Calling HTTP method'), mock.call('invalid response')]
mock_info.assert_has_calls(calls, any_order=True)
- @requests_mock.mock()
- def test_filters_response(self, m):
- m.get('http://www.example.com', json={'value': 5})
+ def test_filters_response(self, httpx_mock):
+ httpx_mock.add_response(url='http://www.example.com', json={'value': 5})
operator = SimpleHttpOperator(
task_id='test_HTTP_op',
method='GET',
diff --git a/tests/providers/http/sensors/test_http.py b/tests/providers/http/sensors/test_http.py
index 23ac2fdf2f69f..595deab718cda 100644
--- a/tests/providers/http/sensors/test_http.py
+++ b/tests/providers/http/sensors/test_http.py
@@ -15,12 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import unittest
from unittest import mock
-from unittest.mock import patch
import pytest
-import requests
from airflow.exceptions import AirflowException, AirflowSensorTimeout
from airflow.models import TaskInstance
@@ -34,23 +31,22 @@
TEST_DAG_ID = 'unit_test_dag'
-class TestHttpSensor(unittest.TestCase):
- def setUp(self):
- args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
- self.dag = DAG(TEST_DAG_ID, default_args=args)
+@pytest.fixture
+def setup_dag():
+ args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
+ yield DAG(TEST_DAG_ID, default_args=args)
- @patch("airflow.providers.http.hooks.http.requests.Session.send")
- def test_poke_exception(self, mock_session_send):
+
+class TestHttpSensor:
+ def test_poke_exception(self, httpx_mock):
"""
Exception occurs in poke function should not be ignored.
"""
- response = requests.Response()
- response.status_code = 200
- mock_session_send.return_value = response
def resp_check(_):
raise AirflowException('AirflowException raised here!')
+ httpx_mock.add_response(status_code=200)
task = HttpSensor(
task_id='http_sensor_poke_exception',
http_conn_id='http_default',
@@ -63,19 +59,14 @@ def resp_check(_):
with pytest.raises(AirflowException, match='AirflowException raised here!'):
task.execute(context={})
- @patch("airflow.providers.http.hooks.http.requests.Session.send")
- def test_poke_continues_for_http_500_with_extra_options_check_response_false(self, mock_session_send):
+ def test_poke_continues_for_http_500_with_extra_options_check_response_false(self, httpx_mock, setup_dag):
def resp_check(_):
return False
- response = requests.Response()
- response.status_code = 500
- response.reason = 'Internal Server Error'
- response._content = b'Internal Server Error'
- mock_session_send.return_value = response
+ httpx_mock.add_response(status_code=500, data="Internal Server Error")
task = HttpSensor(
- dag=self.dag,
+ dag=setup_dag,
task_id='http_sensor_poke_for_code_500',
http_conn_id='http_default',
endpoint='',
@@ -87,16 +78,16 @@ def resp_check(_):
poke_interval=1,
)
- with self.assertRaises(AirflowSensorTimeout):
+ with pytest.raises(AirflowSensorTimeout, match='Snap. Time is OUT. DAG id: unit_test_dag'):
task.execute(context={})
- @patch("airflow.providers.http.hooks.http.requests.Session.send")
- def test_head_method(self, mock_session_send):
+ def test_head_method(self, httpx_mock, setup_dag):
def resp_check(_):
return True
+ httpx_mock.add_response(status_code=200, url='https://www.httpbin.org', method='HEAD')
task = HttpSensor(
- dag=self.dag,
+ dag=setup_dag,
task_id='http_sensor_head_method',
http_conn_id='http_default',
endpoint='',
@@ -109,19 +100,8 @@ def resp_check(_):
task.execute(context={})
- args, kwargs = mock_session_send.call_args
- received_request = args[0]
-
- prep_request = requests.Request('HEAD', 'https://www.httpbin.org', {}).prepare()
-
- assert prep_request.url == received_request.url
- assert prep_request.method, received_request.method
-
- @patch("airflow.providers.http.hooks.http.requests.Session.send")
- def test_poke_context(self, mock_session_send):
- response = requests.Response()
- response.status_code = 200
- mock_session_send.return_value = response
+ def test_poke_context(self, httpx_mock, setup_dag):
+ httpx_mock.add_response(status_code=200, url='https://www.httpbin.org')
def resp_check(_, execution_date):
if execution_date == DEFAULT_DATE:
@@ -136,25 +116,19 @@ def resp_check(_, execution_date):
response_check=resp_check,
timeout=5,
poke_interval=1,
- dag=self.dag,
+ dag=setup_dag,
)
task_instance = TaskInstance(task=task, execution_date=DEFAULT_DATE)
task.execute(task_instance.get_template_context())
- @patch("airflow.providers.http.hooks.http.requests.Session.send")
- def test_logging_head_error_request(self, mock_session_send):
+ def test_logging_head_error_request(self, httpx_mock, setup_dag):
def resp_check(_):
return True
- response = requests.Response()
- response.status_code = 404
- response.reason = 'Not Found'
- response._content = b"This endpoint doesn't exist"
- mock_session_send.return_value = response
-
+ httpx_mock.add_response(status_code=404, data="This endpoint doesn't exist")
task = HttpSensor(
- dag=self.dag,
+ dag=setup_dag,
task_id='http_sensor_head_method',
http_conn_id='http_default',
endpoint='',
@@ -167,7 +141,7 @@ def resp_check(_):
with mock.patch.object(task.hook.log, 'error') as mock_errors:
with pytest.raises(AirflowSensorTimeout):
- task.execute(None)
+ task.execute(context={})
assert mock_errors.called
calls = [
@@ -187,68 +161,51 @@ def resp_check(_):
mock_errors.assert_has_calls(calls)
-class FakeSession:
- def __init__(self):
- self.response = requests.Response()
- self.response.status_code = 200
- self.response._content = 'apache/airflow'.encode('ascii', 'ignore')
-
- def send(self, *args, **kwargs):
- return self.response
+@pytest.fixture
+def setup_op_dag():
+ args = {'owner': 'airflow', 'start_date': DEFAULT_DATE_ISO}
+ yield DAG(TEST_DAG_ID, default_args=args)
- def prepare_request(self, request):
- if 'date' in request.params:
- self.response._content += ('/' + request.params['date']).encode('ascii', 'ignore')
- return self.response
- def merge_environment_settings(self, _url, **kwargs):
- return kwargs
-
-
-class TestHttpOpSensor(unittest.TestCase):
- def setUp(self):
- args = {'owner': 'airflow', 'start_date': DEFAULT_DATE_ISO}
- dag = DAG(TEST_DAG_ID, default_args=args)
- self.dag = dag
-
- @mock.patch('requests.Session', FakeSession)
- def test_get(self):
+class TestHttpOpSensor:
+ def test_get(self, httpx_mock, setup_op_dag):
+ httpx_mock.add_response(status_code=200)
op = SimpleHttpOperator(
task_id='get_op',
method='GET',
endpoint='/search',
data={"client": "ubuntu", "q": "airflow"},
headers={},
- dag=self.dag,
+ dag=setup_op_dag,
)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
- @mock.patch('requests.Session', FakeSession)
- def test_get_response_check(self):
+ def test_get_response_check(self, httpx_mock, setup_op_dag):
+ httpx_mock.add_response(status_code=200, data="apache/airflow")
op = SimpleHttpOperator(
task_id='get_op',
method='GET',
endpoint='/search',
- data={"client": "ubuntu", "q": "airflow"},
response_check=lambda response: ("apache/airflow" in response.text),
headers={},
- dag=self.dag,
+ dag=setup_op_dag,
)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
- @mock.patch('requests.Session', FakeSession)
- def test_sensor(self):
+ def test_sensor(self, httpx_mock, setup_op_dag):
+ httpx_mock.add_response(
+ status_code=200,
+ url="https://www.httpbin.org//search?client=ubuntu&q=airflow&date="
+ + DEFAULT_DATE.strftime('%Y-%m-%d'),
+ )
sensor = HttpSensor(
task_id='http_sensor_check',
http_conn_id='http_default',
endpoint='/search',
request_params={"client": "ubuntu", "q": "airflow", 'date': '{{ds}}'},
headers={},
- response_check=lambda response: (
- "apache/airflow/" + DEFAULT_DATE.strftime('%Y-%m-%d') in response.text
- ),
poke_interval=5,
timeout=15,
- dag=self.dag,
+ dag=setup_op_dag,
)
sensor.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
diff --git a/tests/providers/opsgenie/hooks/test_opsgenie_alert.py b/tests/providers/opsgenie/hooks/test_opsgenie_alert.py
index 0db9ca4cdbdce..04743c93cd5db 100644
--- a/tests/providers/opsgenie/hooks/test_opsgenie_alert.py
+++ b/tests/providers/opsgenie/hooks/test_opsgenie_alert.py
@@ -20,7 +20,6 @@
import unittest
import pytest
-import requests_mock
from airflow.exceptions import AirflowException
from airflow.models import Connection
@@ -31,6 +30,36 @@
class TestOpsgenieAlertHook(unittest.TestCase):
conn_id = 'opsgenie_conn_id_test'
opsgenie_alert_endpoint = 'https://api.opsgenie.com/v2/alerts'
+
+ def setUp(self):
+ db.merge_conn(
+ Connection(
+ conn_id=self.conn_id,
+ conn_type='http',
+ host='https://api.opsgenie.com/',
+ password='eb243592-faa2-4ba2-a551q-1afdf565c889',
+ )
+ )
+
+ def test_get_api_key(self):
+ hook = OpsgenieAlertHook(opsgenie_conn_id=self.conn_id)
+ api_key = hook._get_api_key()
+ assert 'eb243592-faa2-4ba2-a551q-1afdf565c889' == api_key
+
+ def test_get_conn_defaults_host(self):
+ hook = OpsgenieAlertHook()
+ hook.get_conn()
+ assert 'https://api.opsgenie.com' == hook.base_url
+
+
+class TestOpsGenieAlertMockHttpx:
+ conn_id = 'opsgenie_conn_id_test'
+ opsgenie_alert_endpoint = 'https://api.opsgenie.com/v2/alerts'
+ _mock_success_response_body = {
+ "result": "Request will be processed",
+ "took": 0.302,
+ "requestId": "43a29c5c-3dbf-4fa4-9c26-f4f71023e120",
+ }
_payload = {
'message': 'An example alert message',
'alias': 'Life is too short for no alias',
@@ -60,57 +89,35 @@ class TestOpsgenieAlertHook(unittest.TestCase):
'user': 'Jesse',
'note': 'Write this down',
}
- _mock_success_response_body = {
- "result": "Request will be processed",
- "took": 0.302,
- "requestId": "43a29c5c-3dbf-4fa4-9c26-f4f71023e120",
- }
-
- def setUp(self):
- db.merge_conn(
- Connection(
- conn_id=self.conn_id,
- conn_type='http',
- host='https://api.opsgenie.com/',
- password='eb243592-faa2-4ba2-a551q-1afdf565c889',
- )
- )
- def test_get_api_key(self):
- hook = OpsgenieAlertHook(opsgenie_conn_id=self.conn_id)
- api_key = hook._get_api_key()
- assert 'eb243592-faa2-4ba2-a551q-1afdf565c889' == api_key
-
- def test_get_conn_defaults_host(self):
- hook = OpsgenieAlertHook()
- hook.get_conn()
- assert 'https://api.opsgenie.com' == hook.base_url
-
- @requests_mock.mock()
- def test_call_with_success(self, m):
+ def test_call_with_success(self, httpx_mock):
hook = OpsgenieAlertHook(opsgenie_conn_id=self.conn_id)
- m.post(self.opsgenie_alert_endpoint, status_code=202, json=self._mock_success_response_body)
+ httpx_mock.add_response(
+ url=self.opsgenie_alert_endpoint, status_code=202, json=self._mock_success_response_body
+ )
resp = hook.execute(payload=self._payload)
assert resp.status_code == 202
assert resp.json() == self._mock_success_response_body
- @requests_mock.mock()
- def test_api_key_set(self, m):
+ def test_api_key_set(self, httpx_mock):
hook = OpsgenieAlertHook(opsgenie_conn_id=self.conn_id)
- m.post(self.opsgenie_alert_endpoint, status_code=202, json=self._mock_success_response_body)
+ httpx_mock.add_response(
+ url=self.opsgenie_alert_endpoint, status_code=202, json=self._mock_success_response_body
+ )
resp = hook.execute(payload=self._payload)
assert resp.request.headers.get('Authorization') == 'GenieKey eb243592-faa2-4ba2-a551q-1afdf565c889'
- @requests_mock.mock()
- def test_api_key_not_set(self, m):
+ def test_api_key_not_set(self):
hook = OpsgenieAlertHook()
- m.post(self.opsgenie_alert_endpoint, status_code=202, json=self._mock_success_response_body)
with pytest.raises(AirflowException):
hook.execute(payload=self._payload)
- @requests_mock.mock()
- def test_payload_set(self, m):
+ def test_payload_set(self, httpx_mock):
hook = OpsgenieAlertHook(opsgenie_conn_id=self.conn_id)
- m.post(self.opsgenie_alert_endpoint, status_code=202, json=self._mock_success_response_body)
- resp = hook.execute(payload=self._payload)
- assert json.loads(resp.request.body) == self._payload
+ httpx_mock.add_response(
+ url=self.opsgenie_alert_endpoint,
+ status_code=202,
+ json=self._mock_success_response_body,
+ match_content=json.dumps(self._payload).encode("utf-8"),
+ )
+ hook.execute(payload=self._payload)
diff --git a/tests/providers/slack/hooks/test_slack_webhook.py b/tests/providers/slack/hooks/test_slack_webhook.py
index 6fce527a76e23..92603b4388a69 100644
--- a/tests/providers/slack/hooks/test_slack_webhook.py
+++ b/tests/providers/slack/hooks/test_slack_webhook.py
@@ -18,9 +18,6 @@
#
import json
import unittest
-from unittest import mock
-
-from requests.exceptions import MissingSchema
from airflow.models import Connection
from airflow.providers.slack.hooks.slack_webhook import SlackWebhookHook
@@ -129,41 +126,22 @@ def test_build_slack_message(self):
# Then
assert self.expected_message_dict == json.loads(message)
- @mock.patch('requests.Session')
- @mock.patch('requests.Request')
- def test_url_generated_by_http_conn_id(self, mock_request, mock_session):
+
+class TestSlackWebhookMockHttpx:
+ expected_url = 'https://hooks.slack.com/services/T000/B000/XXX'
+ expected_method = 'POST'
+
+ def test_url_generated_by_http_conn_id(self, httpx_mock):
+ httpx_mock.add_response(url=self.expected_url, method=self.expected_method)
hook = SlackWebhookHook(http_conn_id='slack-webhook-url')
- try:
- hook.execute()
- except MissingSchema:
- pass
- mock_request.assert_called_once_with(
- self.expected_method, self.expected_url, headers=mock.ANY, data=mock.ANY
- )
- mock_request.reset_mock()
+ hook.execute()
- @mock.patch('requests.Session')
- @mock.patch('requests.Request')
- def test_url_generated_by_endpoint(self, mock_request, mock_session):
+ def test_url_generated_by_endpoint(self, httpx_mock):
+ httpx_mock.add_response(url=self.expected_url, method=self.expected_method)
hook = SlackWebhookHook(webhook_token=self.expected_url)
- try:
- hook.execute()
- except MissingSchema:
- pass
- mock_request.assert_called_once_with(
- self.expected_method, self.expected_url, headers=mock.ANY, data=mock.ANY
- )
- mock_request.reset_mock()
+ hook.execute()
- @mock.patch('requests.Session')
- @mock.patch('requests.Request')
- def test_url_generated_by_http_conn_id_and_endpoint(self, mock_request, mock_session):
+ def test_url_generated_by_http_conn_id_and_endpoint(self, httpx_mock):
+ httpx_mock.add_response(url=self.expected_url, method=self.expected_method)
hook = SlackWebhookHook(http_conn_id='slack-webhook-host', webhook_token='B000/XXX')
- try:
- hook.execute()
- except MissingSchema:
- pass
- mock_request.assert_called_once_with(
- self.expected_method, self.expected_url, headers=mock.ANY, data=mock.ANY
- )
- mock_request.reset_mock()
+ hook.execute()