diff --git a/providers/dbt/cloud/pyproject.toml b/providers/dbt/cloud/pyproject.toml
index de91020daabc9..6a1bd3804023f 100644
--- a/providers/dbt/cloud/pyproject.toml
+++ b/providers/dbt/cloud/pyproject.toml
@@ -62,6 +62,7 @@ dependencies = [
     "apache-airflow-providers-http",
     "asgiref>=2.3.0",
     "aiohttp>=3.9.2",
+    "tenacity>=8.3.0",
 ]

 # The optional dependencies should be modified in place in the generated file
diff --git a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py
index c7023f1b9230a..70630510968f0 100644
--- a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py
+++ b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py
@@ -17,6 +17,7 @@
 from __future__ import annotations

 import asyncio
+import copy
 import json
 import time
 import warnings
@@ -28,8 +29,10 @@
 import aiohttp
 from asgiref.sync import sync_to_async
+from requests import exceptions as requests_exceptions
 from requests.auth import AuthBase
 from requests.sessions import Session
+from tenacity import AsyncRetrying, RetryCallState, retry_if_exception, stop_after_attempt, wait_exponential

 from airflow.exceptions import AirflowException
 from airflow.providers.http.hooks.http import HttpHook
@@ -174,6 +177,10 @@ class DbtCloudHook(HttpHook):
     Interact with dbt Cloud using the V2 (V3 if supported) API.

     :param dbt_cloud_conn_id: The ID of the :ref:`dbt Cloud connection <howto/connection:dbt-cloud>`.
+    :param timeout_seconds: Optional. The timeout in seconds for HTTP requests. If not provided, no timeout is applied.
+    :param retry_limit: The number of times to retry a request in case of failure. Must be greater than or equal to 1.
+    :param retry_delay: The delay in seconds between retries.
+    :param retry_args: Optional. A dictionary of keyword arguments passed to ``tenacity.AsyncRetrying``/``run_with_advanced_retry``; when given, it takes precedence over ``retry_limit`` and ``retry_delay``.
     """

     conn_name_attr = "dbt_cloud_conn_id"
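For illustration, a minimal sketch of the new hook-level knobs (the connection id "dbt_default" and the run id are assumptions, not part of this patch):

    from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook

    # 60s per-request timeout; up to 3 attempts with exponential backoff starting at 2s
    hook = DbtCloudHook(
        dbt_cloud_conn_id="dbt_default",  # assumed connection id
        timeout_seconds=60,
        retry_limit=3,
        retry_delay=2.0,
    )
    # 429/5xx responses, connection errors, and timeouts are retried; other errors fail fast
    job_run = hook.get_job_run(run_id=728368)  # illustrative run id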
""" conn_name_attr = "dbt_cloud_conn_id" @@ -193,9 +200,39 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]: }, } - def __init__(self, dbt_cloud_conn_id: str = default_conn_name, *args, **kwargs) -> None: + def __init__( + self, + dbt_cloud_conn_id: str = default_conn_name, + timeout_seconds: int | None = None, + retry_limit: int = 1, + retry_delay: float = 1.0, + retry_args: dict[Any, Any] | None = None, + ) -> None: super().__init__(auth_type=TokenAuth) self.dbt_cloud_conn_id = dbt_cloud_conn_id + self.timeout_seconds = timeout_seconds + if retry_limit < 1: + raise ValueError("Retry limit must be greater than or equal to 1") + self.retry_limit = retry_limit + self.retry_delay = retry_delay + + def retry_after_func(retry_state: RetryCallState) -> None: + error_msg = str(retry_state.outcome.exception()) if retry_state.outcome else "Unknown error" + self._log_request_error(retry_state.attempt_number, error_msg) + + if retry_args: + self.retry_args = copy.copy(retry_args) + self.retry_args["retry"] = retry_if_exception(self._retryable_error) + self.retry_args["after"] = retry_after_func + self.retry_args["reraise"] = True + else: + self.retry_args = { + "stop": stop_after_attempt(self.retry_limit), + "wait": wait_exponential(min=self.retry_delay, max=(2**retry_limit)), + "retry": retry_if_exception(self._retryable_error), + "after": retry_after_func, + "reraise": True, + } @staticmethod def _get_tenant_domain(conn: Connection) -> str: @@ -233,6 +270,36 @@ async def get_headers_tenants_from_connection(self) -> tuple[dict[str, Any], str headers["Authorization"] = f"Token {self.connection.password}" return headers, tenant + def _log_request_error(self, attempt_num: int, error: str) -> None: + self.log.error("Attempt %s API Request to DBT failed with reason: %s", attempt_num, error) + + @staticmethod + def _retryable_error(exception: BaseException) -> bool: + if isinstance(exception, requests_exceptions.RequestException): + if isinstance(exception, (requests_exceptions.ConnectionError, requests_exceptions.Timeout)) or ( + exception.response is not None + and (exception.response.status_code >= 500 or exception.response.status_code == 429) + ): + return True + + if isinstance(exception, aiohttp.ClientResponseError): + if exception.status >= 500 or exception.status == 429: + return True + + if isinstance(exception, (aiohttp.ClientConnectorError, TimeoutError)): + return True + + return False + + def _a_get_retry_object(self) -> AsyncRetrying: + """ + Instantiate an async retry object. 
@@ -233,6 +270,36 @@ async def get_headers_tenants_from_connection(self) -> tuple[dict[str, Any], str]:
         headers["Authorization"] = f"Token {self.connection.password}"
         return headers, tenant

+    def _log_request_error(self, attempt_num: int, error: str) -> None:
+        self.log.error("Attempt %s API Request to DBT failed with reason: %s", attempt_num, error)
+
+    @staticmethod
+    def _retryable_error(exception: BaseException) -> bool:
+        if isinstance(exception, requests_exceptions.RequestException):
+            if isinstance(exception, (requests_exceptions.ConnectionError, requests_exceptions.Timeout)) or (
+                exception.response is not None
+                and (exception.response.status_code >= 500 or exception.response.status_code == 429)
+            ):
+                return True
+
+        if isinstance(exception, aiohttp.ClientResponseError):
+            if exception.status >= 500 or exception.status == 429:
+                return True
+
+        if isinstance(exception, (aiohttp.ClientConnectorError, TimeoutError)):
+            return True
+
+        return False
+
+    def _a_get_retry_object(self) -> AsyncRetrying:
+        """
+        Instantiate an async retry object.
+
+        :return: an instance of the ``AsyncRetrying`` class built from ``self.retry_args``
+        """
+        # retry_args sets reraise=True, so once attempts are exhausted the original request
+        # error propagates to the caller instead of tenacity's RetryError
+        return AsyncRetrying(**self.retry_args)
+
     @provide_account_id
     async def get_job_details(
         self, run_id: int, account_id: int | None = None, include_related: list[str] | None = None
@@ -249,17 +316,22 @@ async def get_job_details(
         headers, tenant = await self.get_headers_tenants_from_connection()
         url, params = self.get_request_url_params(tenant, endpoint, include_related)
         proxies = self._get_proxies(self.connection) or {}
+        proxy = proxies.get("https") if proxies and url.startswith("https") else proxies.get("http")
+        extra_request_args = {}

-        async with aiohttp.ClientSession(headers=headers) as session:
-            proxy = proxies.get("https") if proxies and url.startswith("https") else proxies.get("http")
-            extra_request_args = {}
+        if proxy:
+            extra_request_args["proxy"] = proxy

-            if proxy:
-                extra_request_args["proxy"] = proxy
+        timeout = (
+            aiohttp.ClientTimeout(total=self.timeout_seconds) if self.timeout_seconds is not None else None
+        )

-            async with session.get(url, params=params, **extra_request_args) as response:  # type: ignore[arg-type]
-                response.raise_for_status()
-                return await response.json()
+        async with aiohttp.ClientSession(headers=headers, timeout=timeout) as session:
+            async for attempt in self._a_get_retry_object():
+                with attempt:
+                    async with session.get(url, params=params, **extra_request_args) as response:  # type: ignore[arg-type]
+                        response.raise_for_status()
+                        return await response.json()

     async def get_job_status(
         self, run_id: int, account_id: int | None = None, include_related: list[str] | None = None
@@ -297,8 +369,14 @@ def get_conn(self, *args, **kwargs) -> Session:
     def _paginate(
         self, endpoint: str, payload: dict[str, Any] | None = None, proxies: dict[str, str] | None = None
     ) -> list[Response]:
-        extra_options = {"proxies": proxies} if proxies is not None else None
-        response = self.run(endpoint=endpoint, data=payload, extra_options=extra_options)
+        extra_options: dict[str, Any] = {}
+        if self.timeout_seconds is not None:
+            extra_options["timeout"] = self.timeout_seconds
+        if proxies is not None:
+            extra_options["proxies"] = proxies
+        response = self.run_with_advanced_retry(
+            _retry_args=self.retry_args, endpoint=endpoint, data=payload, extra_options=extra_options or None
+        )
         resp_json = response.json()
         limit = resp_json["extra"]["filters"]["limit"]
         num_total_results = resp_json["extra"]["pagination"]["total_count"]
@@ -309,7 +387,12 @@ def _paginate(
             _paginate_payload["offset"] = limit

         while num_current_results < num_total_results:
-            response = self.run(endpoint=endpoint, data=_paginate_payload, extra_options=extra_options)
+            response = self.run_with_advanced_retry(
+                _retry_args=self.retry_args,
+                endpoint=endpoint,
+                data=_paginate_payload,
+                extra_options=extra_options,
+            )
             resp_json = response.json()
             results.append(response)
             num_current_results += resp_json["extra"]["pagination"]["count"]
@@ -328,7 +411,11 @@ def _run_and_get_response(
         self.method = method
         full_endpoint = f"api/{api_version}/accounts/{endpoint}" if endpoint else None
         proxies = self._get_proxies(self.connection)
-        extra_options = {"proxies": proxies} if proxies is not None else None
+        extra_options: dict[str, Any] = {}
+        if self.timeout_seconds is not None:
+            extra_options["timeout"] = self.timeout_seconds
+        if proxies is not None:
+            extra_options["proxies"] = proxies

         if paginate:
             if isinstance(payload, str):
@@ -339,7 +426,12 @@
                 raise ValueError("An endpoint is needed to paginate a response.")

-        return self.run(endpoint=full_endpoint, data=payload, extra_options=extra_options)
+        return self.run_with_advanced_retry(
+            _retry_args=self.retry_args,
+            endpoint=full_endpoint,
+            data=payload,
+            extra_options=extra_options or None,
+        )

     def list_accounts(self) -> list[Response]:
         """
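Grounded in the _retryable_error predicate added above, a few quick examples of how errors are classified (it is a pure staticmethod, safe to call standalone):

    from requests import exceptions as requests_exceptions

    from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook

    assert DbtCloudHook._retryable_error(requests_exceptions.ConnectionError())  # transient: retried
    assert DbtCloudHook._retryable_error(TimeoutError())                         # transient: retried
    assert not DbtCloudHook._retryable_error(ValueError("bad input"))            # non-HTTP error: fails fast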
diff --git a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/operators/dbt.py b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/operators/dbt.py
index 03492a0b23c5c..3ac8c6d544fd5 100644
--- a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/operators/dbt.py
+++ b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/operators/dbt.py
@@ -87,6 +87,7 @@ class DbtCloudRunJobOperator(BaseOperator):
         run. For more information on retry logic, see:
         https://docs.getdbt.com/dbt-cloud/api-v2#/operations/Retry%20Failed%20Job
     :param deferrable: Run operator in the deferrable mode
+    :param hook_params: Extra arguments passed to the DbtCloudHook constructor.
     :return: The ID of the triggered dbt Cloud job run.
     """

@@ -124,6 +125,7 @@ def __init__(
         reuse_existing_run: bool = False,
         retry_from_failure: bool = False,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        hook_params: dict[str, Any] | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -144,6 +146,7 @@ def __init__(
         self.reuse_existing_run = reuse_existing_run
         self.retry_from_failure = retry_from_failure
         self.deferrable = deferrable
+        self.hook_params = hook_params or {}

     def execute(self, context: Context):
         if self.trigger_reason is None:
@@ -273,7 +276,7 @@ def on_kill(self) -> None:
     @cached_property
     def hook(self):
         """Returns DBT Cloud hook."""
-        return DbtCloudHook(self.dbt_cloud_conn_id)
+        return DbtCloudHook(self.dbt_cloud_conn_id, **self.hook_params)

     def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage:
         """
@@ -311,6 +314,7 @@ class DbtCloudGetJobRunArtifactOperator(BaseOperator):
         be returned.
     :param output_file_name: Optional. The desired file name for the download artifact file.
         Defaults to <run_id>_<path> (e.g. "728368_run_results.json").
+    :param hook_params: Extra arguments passed to the DbtCloudHook constructor.
     """

     template_fields = ("dbt_cloud_conn_id", "run_id", "path", "account_id", "output_file_name")
@@ -324,6 +328,7 @@ def __init__(
         account_id: int | None = None,
         step: int | None = None,
         output_file_name: str | None = None,
+        hook_params: dict[str, Any] | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -333,9 +338,10 @@ def __init__(
         self.account_id = account_id
         self.step = step
         self.output_file_name = output_file_name or f"{self.run_id}_{self.path}".replace("/", "-")
+        self.hook_params = hook_params or {}

     def execute(self, context: Context) -> str:
-        hook = DbtCloudHook(self.dbt_cloud_conn_id)
+        hook = DbtCloudHook(self.dbt_cloud_conn_id, **self.hook_params)
         response = hook.get_job_run_artifact(
             run_id=self.run_id, path=self.path, account_id=self.account_id, step=self.step
         )
@@ -370,6 +376,7 @@ class DbtCloudListJobsOperator(BaseOperator):
     :param order_by: Optional. Field to order the result by. Use '-' to indicate reverse order.
         For example, to use reverse order by the run ID use ``order_by=-id``.
     :param project_id: Optional. The ID of a dbt Cloud project.
+    :param hook_params: Extra arguments passed to the DbtCloudHook constructor.
     """

     template_fields = (
@@ -384,6 +391,7 @@ def __init__(
         account_id: int | None = None,
         project_id: int | None = None,
         order_by: str | None = None,
+        hook_params: dict[str, Any] | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -391,9 +399,10 @@ def __init__(
         self.account_id = account_id
         self.project_id = project_id
         self.order_by = order_by
+        self.hook_params = hook_params or {}

     def execute(self, context: Context) -> list:
-        hook = DbtCloudHook(self.dbt_cloud_conn_id)
+        hook = DbtCloudHook(self.dbt_cloud_conn_id, **self.hook_params)
         list_jobs_response = hook.list_jobs(
             account_id=self.account_id, order_by=self.order_by, project_id=self.project_id
         )
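A minimal sketch of threading the new settings through an operator via hook_params (the task, job, and connection ids are illustrative assumptions):

    from airflow.providers.dbt.cloud.operators.dbt import DbtCloudRunJobOperator

    trigger_job = DbtCloudRunJobOperator(
        task_id="trigger_dbt_cloud_job",
        dbt_cloud_conn_id="dbt_default",
        job_id=48617,
        # forwarded as-is to DbtCloudHook(dbt_cloud_conn_id, **hook_params)
        hook_params={"timeout_seconds": 120, "retry_limit": 3, "retry_delay": 2.0},
    )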
""" template_fields = ( @@ -384,6 +391,7 @@ def __init__( account_id: int | None = None, project_id: int | None = None, order_by: str | None = None, + hook_params: dict[str, Any] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -391,9 +399,10 @@ def __init__( self.account_id = account_id self.project_id = project_id self.order_by = order_by + self.hook_params = hook_params or {} def execute(self, context: Context) -> list: - hook = DbtCloudHook(self.dbt_cloud_conn_id) + hook = DbtCloudHook(self.dbt_cloud_conn_id, **self.hook_params) list_jobs_response = hook.list_jobs( account_id=self.account_id, order_by=self.order_by, project_id=self.project_id ) diff --git a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/triggers/dbt.py b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/triggers/dbt.py index c8acf2d81e647..9d4c59473b1f0 100644 --- a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/triggers/dbt.py +++ b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/triggers/dbt.py @@ -36,6 +36,7 @@ class DbtCloudRunJobTrigger(BaseTrigger): :param end_time: Time in seconds to wait for a job run to reach a terminal status. Defaults to 7 days. :param account_id: The ID of a dbt Cloud account. :param poll_interval: polling period in seconds to check for the status. + :param hook_params: Extra arguments passed to the DbtCloudHook constructor. """ def __init__( @@ -45,6 +46,7 @@ def __init__( end_time: float, poll_interval: float, account_id: int | None, + hook_params: dict[str, Any] | None = None, ): super().__init__() self.run_id = run_id @@ -52,6 +54,7 @@ def __init__( self.conn_id = conn_id self.end_time = end_time self.poll_interval = poll_interval + self.hook_params = hook_params or {} def serialize(self) -> tuple[str, dict[str, Any]]: """Serialize DbtCloudRunJobTrigger arguments and classpath.""" @@ -63,12 +66,13 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "conn_id": self.conn_id, "end_time": self.end_time, "poll_interval": self.poll_interval, + "hook_params": self.hook_params, }, ) async def run(self) -> AsyncIterator[TriggerEvent]: """Make async connection to Dbt, polls for the pipeline run status.""" - hook = DbtCloudHook(self.conn_id) + hook = DbtCloudHook(self.conn_id, **self.hook_params) try: while await self.is_still_running(hook): if self.end_time < time.time(): diff --git a/providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py b/providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py index 08bc9c6975704..607bb650871fb 100644 --- a/providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py +++ b/providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py @@ -20,9 +20,11 @@ from copy import deepcopy from datetime import timedelta from typing import Any -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch +import aiohttp import pytest +from requests import exceptions as requests_exceptions from requests.models import Response from airflow.exceptions import AirflowException @@ -95,6 +97,12 @@ def mock_response_json(response: dict): return run_response +def request_exception_with_status(status_code: int) -> requests_exceptions.HTTPError: + response = Response() + response.status_code = status_code + return requests_exceptions.HTTPError(response=response) + + class TestDbtCloudJobRunStatus: valid_job_run_statuses = [ 1, # QUEUED @@ -1072,3 +1080,235 @@ def test_connection_failure(self, requests_mock, conn_id): assert status is False assert msg == "403:Authentication credentials were not provided" + + 
diff --git a/providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py b/providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py
index 08bc9c6975704..607bb650871fb 100644
--- a/providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py
+++ b/providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py
@@ -20,9 +20,11 @@
 from copy import deepcopy
 from datetime import timedelta
 from typing import Any
-from unittest.mock import MagicMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch

+import aiohttp
 import pytest
+from requests import exceptions as requests_exceptions
 from requests.models import Response

 from airflow.exceptions import AirflowException
@@ -95,6 +97,12 @@ def mock_response_json(response: dict):
     return run_response


+def request_exception_with_status(status_code: int) -> requests_exceptions.HTTPError:
+    response = Response()
+    response.status_code = status_code
+    return requests_exceptions.HTTPError(response=response)
+
+
 class TestDbtCloudJobRunStatus:
     valid_job_run_statuses = [
         1,  # QUEUED
@@ -1072,3 +1080,235 @@ def test_connection_failure(self, requests_mock, conn_id):

         assert status is False
         assert msg == "403:Authentication credentials were not provided"
+
+    @pytest.mark.parametrize(
+        argnames="timeout_seconds",
+        argvalues=[60, 180, 300],
+        ids=["60s", "180s", "300s"],
+    )
+    @patch.object(DbtCloudHook, "run_with_advanced_retry")
+    def test_timeout_passed_to_run_and_get_response(self, mock_run_with_retry, timeout_seconds):
+        """Test that timeout is passed to extra_options in _run_and_get_response."""
+        hook = DbtCloudHook(ACCOUNT_ID_CONN, timeout_seconds=timeout_seconds)
+        mock_run_with_retry.return_value = mock_response_json({"data": {"id": JOB_ID}})
+
+        hook.get_job(job_id=JOB_ID, account_id=DEFAULT_ACCOUNT_ID)
+
+        call_args = mock_run_with_retry.call_args
+        assert call_args is not None
+        extra_options = call_args.kwargs.get("extra_options")
+        assert extra_options is not None
+        assert extra_options["timeout"] == timeout_seconds
+
+    @pytest.mark.parametrize(
+        argnames="timeout_seconds",
+        argvalues=[60, 180, 300],
+        ids=["60s", "180s", "300s"],
+    )
+    @patch.object(DbtCloudHook, "run_with_advanced_retry")
+    def test_timeout_passed_to_paginate(self, mock_run_with_retry, timeout_seconds):
+        """Test that timeout is passed to extra_options in _paginate."""
+        hook = DbtCloudHook(ACCOUNT_ID_CONN, timeout_seconds=timeout_seconds)
+        mock_response = mock_response_json(
+            {
+                "data": [{"id": JOB_ID}],
+                "extra": {"filters": {"limit": 100}, "pagination": {"count": 1, "total_count": 1}},
+            }
+        )
+        mock_run_with_retry.return_value = mock_response
+
+        hook.list_jobs(account_id=DEFAULT_ACCOUNT_ID)
+
+        call_args = mock_run_with_retry.call_args
+        assert call_args is not None
+        extra_options = call_args.kwargs.get("extra_options")
+        assert extra_options is not None
+        assert extra_options["timeout"] == timeout_seconds
+
+    @pytest.mark.parametrize(
+        argnames="timeout_seconds",
+        argvalues=[60, 180, 300],
+        ids=["60s", "180s", "300s"],
+    )
+    @patch.object(DbtCloudHook, "run_with_advanced_retry")
+    def test_timeout_with_proxies(self, mock_run_with_retry, timeout_seconds):
+        """Test that both timeout and proxies are passed to extra_options."""
+        hook = DbtCloudHook(PROXY_CONN, timeout_seconds=timeout_seconds)
+        mock_run_with_retry.return_value = mock_response_json({"data": {"id": JOB_ID}})
+
+        hook.get_job(job_id=JOB_ID, account_id=DEFAULT_ACCOUNT_ID)
+
+        call_args = mock_run_with_retry.call_args
+        assert call_args is not None
+        extra_options = call_args.kwargs.get("extra_options")
+        assert extra_options is not None
+        assert extra_options["timeout"] == timeout_seconds
+        assert "proxies" in extra_options
+        assert extra_options["proxies"] == EXTRA_PROXIES["proxies"]
+
+    @pytest.mark.parametrize(
+        argnames="exception, expected",
+        argvalues=[
+            (requests_exceptions.ConnectionError(), True),
+            (requests_exceptions.Timeout(), True),
+            (request_exception_with_status(503), True),
+            (request_exception_with_status(429), True),
+            (request_exception_with_status(404), False),
+            (aiohttp.ClientResponseError(MagicMock(), (), status=500, message=""), True),
+            (aiohttp.ClientResponseError(MagicMock(), (), status=429, message=""), True),
+            (aiohttp.ClientResponseError(MagicMock(), (), status=400, message=""), False),
+            (aiohttp.ClientConnectorError(MagicMock(), OSError()), True),
+            (TimeoutError(), True),
+            (ValueError(), False),
+        ],
+        ids=[
+            "requests_connection_error",
+            "requests_timeout",
+            "requests_status_503",
+            "requests_status_429",
+            "requests_status_404",
+            "aiohttp_status_500",
+            "aiohttp_status_429",
+            "aiohttp_status_400",
+            "aiohttp_connector_error",
+            "timeout_error",
+            "value_error",
+        ],
+    )
+    def test_retryable_error(self, exception, expected):
+        assert DbtCloudHook._retryable_error(exception) is expected
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "error_factory, retry_qty, retry_delay",
+        [
+            (
+                lambda: aiohttp.ClientResponseError(
+                    request_info=AsyncMock(), history=(), status=500, message=""
+                ),
+                3,
+                0.1,
+            ),
+            (
+                lambda: aiohttp.ClientResponseError(
+                    request_info=AsyncMock(), history=(), status=429, message=""
+                ),
+                5,
+                0.1,
+            ),
+            (lambda: aiohttp.ClientConnectorError(AsyncMock(), OSError("boom")), 2, 0.1),
+            (lambda: TimeoutError(), 2, 0.1),
+        ],
+        ids=["aiohttp_500", "aiohttp_429", "connector_error", "timeout"],
+    )
+    @patch("airflow.providers.dbt.cloud.hooks.dbt.aiohttp.ClientSession.get")
+    async def test_get_job_details_retry_with_retryable_errors(
+        self, get_mock, error_factory, retry_qty, retry_delay
+    ):
+        hook = DbtCloudHook(ACCOUNT_ID_CONN, retry_limit=retry_qty, retry_delay=retry_delay)
+
+        def fail_cm():
+            cm = AsyncMock()
+            cm.__aenter__.side_effect = error_factory()
+            return cm
+
+        ok_resp = AsyncMock()
+        ok_resp.raise_for_status = MagicMock(return_value=None)
+        ok_resp.json = AsyncMock(return_value={"data": "Success"})
+        ok_cm = AsyncMock()
+        ok_cm.__aenter__.return_value = ok_resp
+        ok_cm.__aexit__.return_value = AsyncMock()
+
+        all_resp = [fail_cm() for _ in range(retry_qty - 1)]
+        all_resp.append(ok_cm)
+        get_mock.side_effect = all_resp
+
+        result = await hook.get_job_details(run_id=RUN_ID, account_id=None)
+
+        assert result == {"data": "Success"}
+        assert get_mock.call_count == retry_qty
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "error_factory, expected_exception",
+        [
+            (
+                lambda: aiohttp.ClientResponseError(
+                    request_info=AsyncMock(), history=(), status=404, message="Not Found"
+                ),
+                aiohttp.ClientResponseError,
+            ),
+            (
+                lambda: aiohttp.ClientResponseError(
+                    request_info=AsyncMock(), history=(), status=400, message="Bad Request"
+                ),
+                aiohttp.ClientResponseError,
+            ),
+            (lambda: ValueError("Invalid parameter"), ValueError),
+        ],
+        ids=["aiohttp_404", "aiohttp_400", "value_error"],
+    )
+    @patch("airflow.providers.dbt.cloud.hooks.dbt.aiohttp.ClientSession.get")
+    async def test_get_job_details_retry_with_non_retryable_errors(
+        self, get_mock, error_factory, expected_exception
+    ):
+        hook = DbtCloudHook(ACCOUNT_ID_CONN, retry_limit=3, retry_delay=0.1)
+
+        def fail_cm():
+            cm = AsyncMock()
+            cm.__aenter__.side_effect = error_factory()
+            return cm
+
+        get_mock.return_value = fail_cm()
+
+        with pytest.raises(expected_exception):
+            await hook.get_job_details(run_id=RUN_ID, account_id=None)
+
+        assert get_mock.call_count == 1
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        argnames="error_factory, expected_exception",
+        argvalues=[
+            (
+                lambda: aiohttp.ClientResponseError(
+                    request_info=AsyncMock(), history=(), status=503, message="Service Unavailable"
+                ),
+                aiohttp.ClientResponseError,
+            ),
+            (
+                lambda: aiohttp.ClientResponseError(
+                    request_info=AsyncMock(), history=(), status=500, message="Internal Server Error"
+                ),
+                aiohttp.ClientResponseError,
+            ),
+            (
+                lambda: aiohttp.ClientConnectorError(AsyncMock(), OSError("Connection refused")),
+                aiohttp.ClientConnectorError,
+            ),
+            (lambda: TimeoutError("Request timeout"), TimeoutError),
+        ],
+        ids=[
+            "aiohttp_503_exhausted",
+            "aiohttp_500_exhausted",
+            "connector_error_exhausted",
+            "timeout_exhausted",
+        ],
+    )
+    @patch("airflow.providers.dbt.cloud.hooks.dbt.aiohttp.ClientSession.get")
+    async def test_get_job_details_retry_with_exhausted_retries(
+        self, get_mock, error_factory, expected_exception
+    ):
+        hook = DbtCloudHook(ACCOUNT_ID_CONN, retry_limit=2, retry_delay=0.1)
+
+        def fail_cm():
+            cm = AsyncMock()
+            cm.__aenter__.side_effect = error_factory()
+            return cm
+
+        get_mock.side_effect = [fail_cm() for _ in range(2)]
+
+        with pytest.raises(expected_exception):
+            await hook.get_job_details(run_id=RUN_ID, account_id=None)
+
+        assert get_mock.call_count == 2
diff --git a/providers/dbt/cloud/tests/unit/dbt/cloud/triggers/test_dbt.py b/providers/dbt/cloud/tests/unit/dbt/cloud/triggers/test_dbt.py
index 2a1f26b49d838..76818008a45af 100644
--- a/providers/dbt/cloud/tests/unit/dbt/cloud/triggers/test_dbt.py
+++ b/providers/dbt/cloud/tests/unit/dbt/cloud/triggers/test_dbt.py
@@ -45,6 +45,7 @@ def test_serialization(self):
             end_time=self.END_TIME,
             run_id=self.RUN_ID,
             account_id=self.ACCOUNT_ID,
+            hook_params={"retry_delay": 10},
         )
         classpath, kwargs = trigger.serialize()
         assert classpath == "airflow.providers.dbt.cloud.triggers.dbt.DbtCloudRunJobTrigger"
@@ -54,6 +55,7 @@ def test_serialization(self):
             "conn_id": self.CONN_ID,
             "end_time": self.END_TIME,
             "poll_interval": self.POLL_INTERVAL,
+            "hook_params": {"retry_delay": 10},
         }

     @pytest.mark.asyncio