diff --git a/.gitignore b/.gitignore index 6b1153c..28d81a5 100644 --- a/.gitignore +++ b/.gitignore @@ -139,4 +139,7 @@ dmypy.json .idea/ # MacOS files -.DS_Store \ No newline at end of file +.DS_Store + +# Local test script +/docs/examples/test.py diff --git a/README.md b/README.md index ea7b864..0e8aa67 100644 --- a/README.md +++ b/README.md @@ -203,4 +203,74 @@ class KBCStorageClient(HttpClient): cl = KBCStorageClient("my_token") print(cl.get_files()) -``` \ No newline at end of file +``` + +## Async Usage + +The package also provides an asynchronous version of the HTTP client called AsyncHttpClient. +It allows you to make asynchronous requests using async/await syntax. To use the AsyncHttpClient, import it from keboola.http_client_async: + +```python +from keboola.http_client import AsyncHttpClient +``` + +The AsyncHttpClient class provides similar functionality as the HttpClient class, but with asynchronous methods such as get, post, put, patch, and delete that return awaitable coroutines. +You can use these methods within async functions to perform non-blocking HTTP requests. + +```python +import asyncio +from keboola.http_client import AsyncHttpClient + +async def main(): + base_url = "https://api.example.com/" + async with AsyncHttpClient(base_url) as client: + response = await client.get("endpoint") + + if response.status_code == 200: + data = response.json() + # Process the response data + else: + # Handle the error + +asyncio.run(main()) +``` + +The AsyncHttpClient provides similar initialization and request methods as the HttpClient. +The request methods return awaitable coroutines that can be awaited in an asynchronous context. + +#### Building HTTP client based on AsyncHttpClient Example +This example demonstrates the default use of the HTTPClient as a base for REST API clients. + +```python +import asyncio +from keboola.http_client import AsyncHttpClient + +BASE_URL = 'https://connection.keboola.com/v2/storage' +MAX_RETRIES = 3 + +class KBCStorageClient(AsyncHttpClient): + + def __init__(self, storage_token): + AsyncHttpClient.__init__( + self, + base_url=BASE_URL, + retries=MAX_RETRIES, + backoff_factor=0.3, + retry_status_codes=[429, 500, 502, 504], + auth_header={"X-StorageApi-Token": storage_token} + ) + + async def get_files(self, show_expired=False): + params = {"showExpired": show_expired} + response = await self.get('tables', params=params, timeout=5) + return response + +async def main(): + cl = KBCStorageClient("my_token") + files = await cl.get_files(show_expired=False) + print(files) + +asyncio.run(main()) +``` +**Note:** Since there are no parallel requests being made, you won't notice any speedup for this use case. +For an example where you can see the speedup thanks to async requests, you can view the pokeapi.py in docs/examples. diff --git a/docs/examples/poekapi_async.py b/docs/examples/poekapi_async.py new file mode 100644 index 0000000..915cac0 --- /dev/null +++ b/docs/examples/poekapi_async.py @@ -0,0 +1,61 @@ +import time +import asyncio +from keboola.http_client import AsyncHttpClient +import csv +import httpx +import os + + +async def fetch_pokemon(client, poke_id): + try: + r = await client.get(str(poke_id)) + return r + except httpx.HTTPStatusError as e: + if e.response.status_code == 404: + return None + else: + raise + + +async def save_to_csv(details): + filename = "pokemon_details.csv" + fieldnames = ["name", "height", "weight"] + + file_exists = os.path.isfile(filename) + mode = "a" if file_exists else "w" + + with open(filename, mode, newline="") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + + if not file_exists: + writer.writeheader() + + writer.writerow({ + "name": details["name"], + "height": details["height"], + "weight": details["weight"] + }) + + +async def main_async(): + base_url = "https://pokeapi.co/api/v2/pokemon/" + start_time = time.time() + + async with AsyncHttpClient(base_url=base_url, max_requests_per_second=20) as c: + poke_id = 1 + + while True: + details = await fetch_pokemon(c, poke_id) + if details is None: + break + + await save_to_csv(details) + + poke_id += 1 + + end_time = time.time() + print(f"Async: Fetched details for {poke_id - 1} Pokémon in {end_time - start_time:.2f} seconds.") + + +if __name__ == "__main__": + asyncio.run(main_async()) diff --git a/docs/examples/pokeapi_process_multiple.py b/docs/examples/pokeapi_process_multiple.py new file mode 100644 index 0000000..1681877 --- /dev/null +++ b/docs/examples/pokeapi_process_multiple.py @@ -0,0 +1,44 @@ +import asyncio +import csv +import time +from typing import List + +from keboola.http_client import AsyncHttpClient + + +def generate_jobs(nr_of_jobs): + return [{'method': 'GET', 'endpoint': str(endpoint)} for endpoint in range(1, nr_of_jobs+1)] + +def save_to_csv(results: List[dict]): + filename = "pokemon_details.csv" + fieldnames = ["name", "height", "weight"] # Define the fields you want to store + + with open(filename, "w", newline="") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + writer.writeheader() + for result in results: + writer.writerow({ + "name": result["name"], + "height": result["height"], + "weight": result["weight"] + }) + +async def main_async(): + base_url = "https://pokeapi.co/api/v2/pokemon/" + start_time = time.time() + + client = AsyncHttpClient(base_url=base_url, max_requests_per_second=20) + + jobs = generate_jobs(1000) + + results = await client.process_multiple(jobs) + await client.close() + + end_time = time.time() + print(f"Fetched details for {len(results)} Pokémon in {end_time - start_time:.2f} seconds.") + + save_to_csv(results) + + +if __name__ == "__main__": + asyncio.run(main_async()) diff --git a/docs/examples/storage_client.py b/docs/examples/storage_client.py new file mode 100644 index 0000000..d82dbd1 --- /dev/null +++ b/docs/examples/storage_client.py @@ -0,0 +1,20 @@ +from keboola.http_client import HttpClient + +BASE_URL = 'https://connection.keboola.com/v2/storage' +MAX_RETRIES = 3 + + +class KBCStorageClient(HttpClient): + + def __init__(self, storage_token): + HttpClient.__init__(self, base_url=BASE_URL, max_retries=MAX_RETRIES, backoff_factor=0.3, + status_forcelist=(429, 500, 502, 504), + default_http_header={"X-StorageApi-Token": storage_token}) + + def get_files(self, show_expired=None): + params = {"include": show_expired} + return self.get('tables', params=params, timeout=5) + +cl = KBCStorageClient("my_token") + +print(cl.get_files()) \ No newline at end of file diff --git a/docs/examples/storage_client_async.py b/docs/examples/storage_client_async.py new file mode 100644 index 0000000..b1ca73f --- /dev/null +++ b/docs/examples/storage_client_async.py @@ -0,0 +1,28 @@ +import asyncio +from keboola.http_client import AsyncHttpClient + +BASE_URL = 'https://connection.keboola.com/v2/storage' +MAX_RETRIES = 3 + +class KBCStorageClient(AsyncHttpClient): + + def __init__(self, storage_token): + super().__init__( + base_url=BASE_URL, + retries=MAX_RETRIES, + backoff_factor=0.3, + retry_status_codes=[429, 500, 502, 504], + auth_header={"X-StorageApi-Token": storage_token} + ) + + async def get_files(self, show_expired=False): + params = {"showExpired": show_expired} + response = await self.get('tables', params=params, timeout=5) + return response + +async def main(): + cl = KBCStorageClient("my_token") + files = await cl.get_files(show_expired=False) + print(files) + +asyncio.run(main()) diff --git a/requirements.txt b/requirements.txt index 663bd1f..af4b5ac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,3 @@ -requests \ No newline at end of file +requests +httpx==0.27.0 +aiolimiter==1.1.0 diff --git a/setup.py b/setup.py index 72b8c44..7999ea4 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,8 @@ setup_requires=['pytest-runner', 'flake8'], tests_require=['pytest'], install_requires=[ - 'requests' + 'requests', + 'httpx' ], author_email="support@keboola.com", description="General HTTP requests library for Python applications running in Keboola Connection environment", diff --git a/src/keboola/http_client/__init__.py b/src/keboola/http_client/__init__.py index 7d9f3d0..05f88ab 100644 --- a/src/keboola/http_client/__init__.py +++ b/src/keboola/http_client/__init__.py @@ -1 +1,2 @@ -from .http import HttpClient # noqa +from .client import HttpClient # noqa +from .async_client import AsyncHttpClient # noqa diff --git a/src/keboola/http_client/async_client.py b/src/keboola/http_client/async_client.py new file mode 100644 index 0000000..ff93494 --- /dev/null +++ b/src/keboola/http_client/async_client.py @@ -0,0 +1,241 @@ +import httpx +import asyncio +from typing import Optional, Dict, Any, List +from urllib.parse import urljoin +from aiolimiter import AsyncLimiter +import logging + + +class AsyncHttpClient: + """ + An asynchronous HTTP client that simplifies making requests to a specific API. + """ + ALLOWED_METHODS = ['GET', 'POST', 'PATCH', 'UPDATE', 'PUT', 'DELETE'] + def __init__( + self, + base_url: str, + retries: int = 3, + timeout: Optional[float] = None, + verify_ssl: bool = True, + retry_status_codes: Optional[List[int]] = None, + max_requests_per_second: Optional[float] = None, + default_params: Optional[Dict[str, str]] = None, + auth: Optional[tuple] = None, + auth_header: Optional[Dict[str, str]] = None, + default_headers: Optional[Dict[str, str]] = None, + backoff_factor: float = 2.0, + debug: bool = False + ): + """ + Initialize the AsyncHttpClient instance. + + Args: + base_url (str): The base URL for the API. + retries (int, optional): The maximum number of retries for failed requests. Defaults to 3. + timeout (Optional[float], optional): The request timeout in seconds. Defaults to None. + verify_ssl (bool, optional): Enable or disable SSL verification. Defaults to True. + retry_status_codes (Optional[List[int]], optional): List of status codes to retry on. Defaults to None. + max_requests_per_second (Optional[float], optional): Maximum number of requests per second. Defaults to None. + default_params (Optional[Dict[str, str]], optional): Default query parameters for each request. + auth (Optional[tuple], optional): Authentication credentials for each request. Defaults to None. + auth_header (Optional[Dict[str, str]], optional): Authentication header for each request. Defaults to None. + backoff_factor (float, optional): The backoff factor for retries. Defaults to 2.0. + """ + self.base_url = base_url if base_url.endswith("/") else base_url + "/" + self.retries = retries + self.timeout = httpx.Timeout(timeout) if timeout else None + self.verify_ssl = verify_ssl + self.retry_status_codes = retry_status_codes or [429, 500, 502, 504] + self.default_params = default_params or {} + self.auth = auth + self._auth_header = auth_header or {} + + self.limiter = None + if max_requests_per_second: + one_reqeust_per_second_amount = float(1/max_requests_per_second) + self.limiter = AsyncLimiter(1, one_reqeust_per_second_amount) + + self.default_headers = default_headers or {} + self.backoff_factor = backoff_factor + + self.client = httpx.AsyncClient(timeout=self.timeout, verify=self.verify_ssl, headers=self.default_headers, + auth=self.auth) + + if not debug: + logging.getLogger("httpx").setLevel(logging.WARNING) + logging.getLogger("httpcore").setLevel(logging.WARNING) + + async def _build_url(self, endpoint_path: Optional[str] = None, is_absolute_path=False) -> str: + # build URL Specification + url_path = str(endpoint_path).strip() if endpoint_path is not None else '' + + if not url_path: + url = self.base_url + elif is_absolute_path: + url = endpoint_path + else: + url = urljoin(self.base_url, endpoint_path) + + return url + + async def update_auth_header(self, updated_header: Dict, overwrite: bool = False): + """ + Updates the default auth header by providing new values. + + Args: + updated_header: An updated header which will be used to update the current header. + overwrite: If `False`, the existing header will be updated with new header. If `True`, the new header will + overwrite (replace) the current authentication header. + """ + + if overwrite is False: + self._auth_header.update(updated_header) + else: + self._auth_header = updated_header + + async def __aenter__(self): + await self.client.__aenter__() + return self + + async def __aexit__(self, *args): + await self.client.__aexit__(*args) + + async def close(self): + await self.client.aclose() + + async def _request( + self, + method: str, + endpoint: Optional[str] = None, + params: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, str]] = None, + **kwargs + ) -> httpx.Response: + + + is_absolute_path = kwargs.pop('is_absolute_path', False) + url = await self._build_url(endpoint, is_absolute_path) + + all_params = {**self.default_params, **(params or {})} + + ignore_auth = kwargs.pop('ignore_auth', False) + if ignore_auth: + all_headers = {**self.default_headers, **(headers or {})} + else: + all_headers = {**self._auth_header, **self.default_headers, **(headers or {})} + if self.auth: + kwargs.update({'auth': self.auth}) + + if all_params: + kwargs.update({'params': all_params}) + if all_headers: + kwargs.update({'headers': all_headers}) + + response = None + + for retry_attempt in range(self.retries + 1): + try: + if self.limiter: + async with self.limiter: + response = await self.client.request(method, url=url, **kwargs) + else: + response = await self.client.request(method, url=url, **kwargs) + + response.raise_for_status() + + return response + + except httpx.HTTPError as e: + + if not isinstance(e, httpx.ReadTimeout): + message = response.text if response and response.text else str(e) + e.args = (f"Error '{e.response.status_code} {message}' for url '{e.request.url}'",) + + if response: + if response.status_code not in self.retry_status_codes: + raise + + if retry_attempt == self.retries: + raise + backoff = self.backoff_factor ** retry_attempt + await asyncio.sleep(backoff) + + logging.error(f"Retry attempt {retry_attempt + 1} for {method} request to {url}: {message}") + + async def get(self, endpoint: Optional[str] = None, **kwargs) -> Dict[str, Any]: + response = await self.get_raw(endpoint, **kwargs) + return response.json() + + async def get_raw(self, endpoint: Optional[str] = None, **kwargs) -> httpx.Response: + return await self._request("GET", endpoint, **kwargs) + + async def post(self, endpoint: Optional[str] = None, **kwargs) -> Dict[str, Any]: + response = await self.post_raw(endpoint, **kwargs) + return response.json() + + async def post_raw(self, endpoint: Optional[str] = None, **kwargs) -> httpx.Response: + return await self._request("POST", endpoint, **kwargs) + + async def put(self, endpoint: Optional[str] = None, **kwargs) -> Dict[str, Any]: + response = await self.put_raw(endpoint, **kwargs) + return response.json() + + async def put_raw(self, endpoint: Optional[str] = None, **kwargs) -> httpx.Response: + return await self._request("PUT", endpoint, **kwargs) + + async def patch(self, endpoint: Optional[str] = None, **kwargs) -> Dict[str, Any]: + response = await self.patch_raw(endpoint, **kwargs) + return response.json() + + async def patch_raw(self, endpoint: Optional[str] = None, **kwargs) -> httpx.Response: + return await self._request("PATCH", endpoint, **kwargs) + + async def delete(self, endpoint: Optional[str] = None, **kwargs) -> Dict[str, Any]: + response = await self.delete_raw(endpoint, **kwargs) + return response.json() + + async def delete_raw(self, endpoint: Optional[str] = None, **kwargs) -> httpx.Response: + return await self._request("DELETE", endpoint, **kwargs) + + async def process_multiple(self, jobs: List[Dict[str, Any]]): + tasks = [] + + for job in jobs: + method = job['method'] + endpoint = job['endpoint'] + params = job.get('params') + headers = job.get('headers') + raw = job.get('raw', False) + + if method == 'GET': + if raw: + task = self.get_raw(endpoint, params=params, headers=headers) + else: + task = self.get(endpoint, params=params, headers=headers) + elif method == 'POST': + if raw: + task = self.post_raw(endpoint, params=params, headers=headers) + else: + task = self.post(endpoint, params=params, headers=headers) + elif method == 'PUT': + if raw: + task = self.put_raw(endpoint, params=params, headers=headers) + else: + task = self.put(endpoint, params=params, headers=headers) + elif method == 'PATCH': + if raw: + task = self.patch_raw(endpoint, params=params, headers=headers) + else: + task = self.patch(endpoint, params=params, headers=headers) + elif method == 'DELETE': + if raw: + task = self.delete_raw(endpoint, params=params, headers=headers) + else: + task = self.delete(endpoint, params=params, headers=headers) + else: + raise ValueError(f"Unsupported method: {method}") + + tasks.append(task) + + responses = await asyncio.gather(*tasks) + return responses diff --git a/src/keboola/http_client/http.py b/src/keboola/http_client/client.py similarity index 99% rename from src/keboola/http_client/http.py rename to src/keboola/http_client/client.py index 8f9e1b8..d465391 100644 --- a/src/keboola/http_client/http.py +++ b/src/keboola/http_client/client.py @@ -6,7 +6,7 @@ import requests from requests.adapters import HTTPAdapter -from requests.packages.urllib3.util.retry import Retry +from requests.packages.urllib3.util.retry import Retry # noqa Cookie = Union[Dict[str, str], CookieJar] diff --git a/tests/test_async.py b/tests/test_async.py new file mode 100644 index 0000000..a77d921 --- /dev/null +++ b/tests/test_async.py @@ -0,0 +1,310 @@ +import unittest +from unittest.mock import patch +import httpx +from keboola.http_client import AsyncHttpClient + + +class TestAsyncHttpClient(unittest.IsolatedAsyncioTestCase): + base_url = "https://api.example.com" + retries = 3 + + async def test_get(self): + expected_response = {"message": "Success"} + mock_response = httpx.Response(200, json=expected_response) + mock_response._request = httpx.Request("GET", "https://api.example.com/endpoint") + + client = AsyncHttpClient(self.base_url, retries=self.retries) + + with patch.object(httpx.AsyncClient, 'request', return_value=mock_response) as mock_request: + response = await client.get("/endpoint") + self.assertEqual(response, expected_response) + mock_request.assert_called_once_with("GET", url="https://api.example.com/endpoint") + + async def test_post(self): + expected_response = {"message": "Success"} + mock_response = httpx.Response(200, json=expected_response) + mock_response._request = httpx.Request("POST", "https://api.example.com/endpoint") + + client = AsyncHttpClient(self.base_url, retries=self.retries) + + with patch.object(httpx.AsyncClient, 'request', return_value=mock_response) as mock_request: + response = await client.post("/endpoint", json={"data": "example"}) + self.assertEqual(response, expected_response) + mock_request.assert_called_once_with("POST", url="https://api.example.com/endpoint", + json={"data": "example"}) + + async def test_handle_success_response(self): + expected_response = {"message": "Success"} + mock_response = httpx.Response(200, json=expected_response) + mock_response._request = httpx.Request("GET", "https://api.example.com/endpoint") + + client = AsyncHttpClient(self.base_url, retries=self.retries) + + with patch.object(httpx.AsyncClient, 'request', return_value=mock_response) as mock_request: + response = await client.get("/endpoint") + self.assertEqual(response, expected_response) + mock_request.assert_called_once_with("GET", url="https://api.example.com/endpoint") + + async def test_handle_client_error_response(self): + mock_response = httpx.Response(404) + mock_response._request = httpx.Request("GET", "https://api.example.com/endpoint") + + client = AsyncHttpClient(self.base_url, retries=self.retries, retry_status_codes=[404]) + + with patch.object(httpx.AsyncClient, 'request', return_value=mock_response) as mock_request: + with self.assertRaises(httpx.HTTPStatusError): + await client.get("/endpoint") + + assert mock_request.call_count == self.retries + 1 + + mock_request.assert_called_with("GET", url="https://api.example.com/endpoint") + + async def test_handle_server_error_response(self): + mock_response = httpx.Response(500) + mock_response._request = httpx.Request("GET", "https://api.example.com/endpoint") + + client = AsyncHttpClient(self.base_url, retries=self.retries, retry_status_codes=[500]) + + with patch.object(httpx.AsyncClient, 'request', return_value=mock_response) as mock_request: + with self.assertRaises(httpx.HTTPStatusError): + await client.get("/endpoint") + + assert mock_request.call_count == self.retries + 1 + + mock_request.assert_called_with("GET", url="https://api.example.com/endpoint") + + @patch.object(httpx.AsyncClient, 'request') + async def test_post_raw_default_pars_with_none_custom_pars_passes(self, mock_request): + url = f"{self.base_url}/endpoint" + test_def_par = {"default_par": "test"} + + client = AsyncHttpClient(self.base_url, retries=self.retries) + + await client.post_raw("/endpoint", params=test_def_par) + + mock_request.assert_called_once_with("POST", url=url, params=test_def_par) + + @patch.object(httpx.AsyncClient, 'request') + async def test_post_default_pars_with_none_custom_pars_passes(self, mock_request): + url = f"{self.base_url}/endpoint" + test_def_par = {"default_par": "test"} + + client = AsyncHttpClient(self.base_url, retries=self.retries) + + await client.post("/endpoint", params=test_def_par) + + mock_request.assert_called_once_with("POST", url=url, params=test_def_par) + + @patch.object(httpx.AsyncClient, 'request') + async def test_post_raw_default_pars_with_custom_pars_passes(self, mock_request): + url = f"{self.base_url}/endpoint" + test_def_par = {"default_par": "test"} + cust_par = {"custom_par": "custom_par_value"} + + client = AsyncHttpClient(self.base_url, retries=self.retries, default_params=test_def_par) + + await client.post_raw("/endpoint", params=cust_par) + + test_cust_def_par = {**test_def_par, **cust_par} + mock_request.assert_called_once_with("POST", url=url, params=test_cust_def_par) + + @patch.object(httpx.AsyncClient, 'request') + async def test_post_default_pars_with_custom_pars_passes(self, mock_request): + url = f"{self.base_url}/endpoint" + test_def_par = {"default_par": "test"} + cust_par = {"custom_par": "custom_par_value"} + + client = AsyncHttpClient(self.base_url, retries=self.retries, default_params=test_def_par) + + await client.post("/endpoint", params=cust_par) + + test_cust_def_par = {**test_def_par, **cust_par} + mock_request.assert_called_once_with("POST", url=url, params=test_cust_def_par) + + @patch.object(httpx.AsyncClient, 'request') + async def test_post_raw_default_pars_with_custom_pars_to_None_passes(self, mock_request): + url = f"{self.base_url}/endpoint" + test_def_par = {"default_par": "test"} + cust_par = None + + client = AsyncHttpClient(self.base_url, retries=self.retries, default_params=test_def_par) + + await client.post_raw("/endpoint", params=cust_par) + + # post_raw changes None to empty dict + _cust_par_transformed = {} + test_cust_def_par = {**test_def_par, **_cust_par_transformed} + mock_request.assert_called_once_with("POST", url=url, params=test_cust_def_par) + + @patch.object(httpx.AsyncClient, 'request') + async def test_post_default_pars_with_custom_pars_to_None_passes(self, mock_request): + url = f"{self.base_url}/endpoint" + test_def_par = {"default_par": "test"} + cust_par = None + + client = AsyncHttpClient(self.base_url, retries=self.retries, default_params=test_def_par) + + await client.post("/endpoint", params=cust_par) + + # post_raw changes None to empty dict + _cust_par_transformed = {} + test_cust_def_par = {**test_def_par, **_cust_par_transformed} + mock_request.assert_called_once_with("POST", url=url, params=test_cust_def_par) + + @patch.object(httpx.AsyncClient, 'request') + async def test_all_methods_requests_raw_with_custom_pars_passes(self, mock_request): + client = AsyncHttpClient(self.base_url) + + cust_par = {"custom_par": "custom_par_value"} + + for m in client.ALLOWED_METHODS: + await client._request(m, ignore_auth=False, params=cust_par) + mock_request.assert_called_with(m, url=self.base_url+"/", params=cust_par) + + @patch.object(httpx.AsyncClient, 'request') + async def test_all_methods_skip_auth(self, mock_request): + client = AsyncHttpClient(self.base_url, auth=("my_user", "password123")) + + for m in ['GET', 'POST', 'PATCH', 'UPDATE', 'PUT', 'DELETE']: + await client._request(m, ignore_auth=True) + mock_request.assert_called_with(m, url=self.base_url+"/") + + @patch.object(httpx.AsyncClient, 'request') + async def test_request_skip_auth_header(self, mock_request): + def_header = {"def_header": "test"} + client = AsyncHttpClient('http://example.com', default_headers=def_header, + auth_header={"Authorization": "test"}) + + await client._request('POST', 'abc', ignore_auth=True) + mock_request.assert_called_with('POST', url="http://example.com/abc", headers=def_header) + + @patch.object(httpx.AsyncClient, 'request') + async def test_request_auth(self, mock_request): + def_header = {"def_header": "test"} + auth = ("my_user", "password123") + client = AsyncHttpClient(self.base_url, auth=auth, default_headers=def_header) + + await client._request('POST', 'abc') + mock_request.assert_called_with('POST', url=self.base_url+"/abc", headers=def_header, + auth=auth) + + @patch.object(httpx.AsyncClient, 'request') + async def test_all_methods(self, mock_request): + client = AsyncHttpClient(self.base_url, default_headers={'header1': 'headerval'}, + auth_header={'api_token': 'abdc1234'}) + + target_url = f'{self.base_url}/abc' + + for m in client.ALLOWED_METHODS: + await client._request(m, 'abc', params={'exclude': 'componentDetails'}, headers={'abc': '123'}, + data={'attr1': 'val1'}) + mock_request.assert_called_with(m, url=target_url, + params={'exclude': 'componentDetails'}, + headers={'api_token': 'abdc1234', 'header1': 'headerval', 'abc': '123'}, + data={'attr1': 'val1'}) + + @patch.object(httpx.AsyncClient, 'request') + async def test_all_methods_requests_raw_with_is_absolute_path_true(self, mock_request): + def_header = {"def_header": "test"} + client = AsyncHttpClient(self.base_url, default_headers=def_header) + + for m in client.ALLOWED_METHODS: + await client._request(m, 'http://example2.com/v1/', is_absolute_path=True) + mock_request.assert_called_with(m, url='http://example2.com/v1/', headers=def_header) + + @patch.object(httpx.AsyncClient, 'request') + async def test_all_methods_requests_raw_with_is_absolute_path_false(self, mock_request): + def_header = {"def_header": "test"} + client = AsyncHttpClient(self.base_url, default_headers=def_header) + + for m in client.ALLOWED_METHODS: + await client._request(m, 'cars') + mock_request.assert_called_with(m, url=self.base_url+"/cars", headers=def_header) + + @patch.object(httpx.AsyncClient, 'request') + async def test_all_methods_kwargs(self, mock_request): + client = AsyncHttpClient(self.base_url) + + for m in client.ALLOWED_METHODS: + await client._request(m, 'cars', data={'data': '123'}, cert='/path/to/cert', files={'a': '/path/to/file'}, + params={'par1': 'val1'}) + + mock_request.assert_called_with(m, url=self.base_url+"/cars", data={'data': '123'}, + cert='/path/to/cert', files={'a': '/path/to/file'}, + params={'par1': 'val1'}) + + async def test_build_url_rel_path(self): + url = 'https://example.com/' + cl = AsyncHttpClient(url) + expected_url = 'https://example.com/storage' + actual_url = await cl._build_url('storage') + self.assertEqual(expected_url, actual_url) + + async def test_build_url_abs_path(self): + url = 'https://example.com/' + cl = AsyncHttpClient(url) + expected_url = 'https://example2.com/storage' + actual_url = await cl._build_url('https://example2.com/storage', True) + self.assertEqual(expected_url, actual_url) + + async def test_build_url_empty_endpoint_path_leads_to_base_url(self): + url = 'https://example.com/' + cl = AsyncHttpClient(url) + expected_url = url + + actual_url = await cl._build_url() + self.assertEqual(expected_url, actual_url) + + actual_url = await cl._build_url('') + self.assertEqual(expected_url, actual_url) + + actual_url = await cl._build_url(None) + self.assertEqual(expected_url, actual_url) + + actual_url = await cl._build_url('', is_absolute_path=True) + self.assertEqual(expected_url, actual_url) + + actual_url = await cl._build_url(None, is_absolute_path=True) + self.assertEqual(expected_url, actual_url) + + async def test_build_url_base_url_appends_slash(self): + url = 'https://example.com' + cl = AsyncHttpClient(url) + expected_base_url = 'https://example.com/' + + self.assertEqual(expected_base_url, cl.base_url) + + async def test_update_auth_header_None(self): + existing_header = None + new_header = {'api_token': 'token_value'} + + cl = AsyncHttpClient('https://example.com', auth_header=existing_header) + await cl.update_auth_header(new_header, overwrite=False) + self.assertDictEqual(cl._auth_header, new_header) + + new_header_2 = {'password': '123'} + await cl.update_auth_header(new_header_2, overwrite=True) + self.assertDictEqual(cl._auth_header, new_header_2) + + async def test_update_existing_auth_header(self): + existing_header = {'authorization': 'value'} + new_header = {'api_token': 'token_value'} + + cl = AsyncHttpClient('https://example.com', auth_header=existing_header) + await cl.update_auth_header(new_header, overwrite=False) + self.assertDictEqual(cl._auth_header, {**existing_header, **new_header}) + + async def test_detailed_exception(self): + mock_response = httpx.Response(404, text="Not Found Because of x") + mock_response._request = httpx.Request("GET", "https://api.example.com/endpoint") + + client = AsyncHttpClient(self.base_url) + + with patch.object(httpx.AsyncClient, 'request', return_value=mock_response) as mock_request: + with self.assertRaises(httpx.HTTPStatusError) as e: + await client.get("/endpoint") + + assert "Error '404 Not Found Because of x' for url 'https://api.example.com/endpoint'" in str(e.exception) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_http.py b/tests/test_http.py index 99dc66d..82752b9 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -2,7 +2,7 @@ import urllib.parse as urlparse from unittest.mock import patch -import keboola.http_client.http as client +import keboola.http_client.client as client class TestClientBase(unittest.TestCase):