Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions providers/dbt/cloud/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ dependencies = [
"apache-airflow-providers-http",
"asgiref>=2.3.0",
"aiohttp>=3.9.2",
"tenacity>=8.3.0",
]

# The optional dependencies should be modified in place in the generated file
Expand Down
120 changes: 106 additions & 14 deletions providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import asyncio
import copy
import json
import time
import warnings
Expand All @@ -28,8 +29,10 @@

import aiohttp
from asgiref.sync import sync_to_async
from requests import exceptions as requests_exceptions
from requests.auth import AuthBase
from requests.sessions import Session
from tenacity import AsyncRetrying, RetryCallState, retry_if_exception, stop_after_attempt, wait_exponential

from airflow.exceptions import AirflowException
from airflow.providers.http.hooks.http import HttpHook
Expand Down Expand Up @@ -174,6 +177,10 @@ class DbtCloudHook(HttpHook):
Interact with dbt Cloud using the V2 (V3 if supported) API.

:param dbt_cloud_conn_id: The ID of the :ref:`dbt Cloud connection <howto/connection:dbt-cloud>`.
:param timeout_seconds: Optional. The timeout in seconds for HTTP requests. If not provided, no timeout is applied.
:param retry_limit: The number of times to retry a request in case of failure.
:param retry_delay: The delay in seconds between retries.
:param retry_args: A dictionary of arguments to pass to the `tenacity.retry` decorator.
"""

conn_name_attr = "dbt_cloud_conn_id"
Expand All @@ -193,9 +200,39 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]:
},
}

def __init__(self, dbt_cloud_conn_id: str = default_conn_name, *args, **kwargs) -> None:
def __init__(
self,
dbt_cloud_conn_id: str = default_conn_name,
timeout_seconds: int | None = None,
retry_limit: int = 1,
retry_delay: float = 1.0,
retry_args: dict[Any, Any] | None = None,
) -> None:
super().__init__(auth_type=TokenAuth)
self.dbt_cloud_conn_id = dbt_cloud_conn_id
self.timeout_seconds = timeout_seconds
if retry_limit < 1:
raise ValueError("Retry limit must be greater than or equal to 1")
self.retry_limit = retry_limit
self.retry_delay = retry_delay

def retry_after_func(retry_state: RetryCallState) -> None:
error_msg = str(retry_state.outcome.exception()) if retry_state.outcome else "Unknown error"
self._log_request_error(retry_state.attempt_number, error_msg)

if retry_args:
self.retry_args = copy.copy(retry_args)
self.retry_args["retry"] = retry_if_exception(self._retryable_error)
self.retry_args["after"] = retry_after_func
self.retry_args["reraise"] = True
else:
self.retry_args = {
"stop": stop_after_attempt(self.retry_limit),
"wait": wait_exponential(min=self.retry_delay, max=(2**retry_limit)),
"retry": retry_if_exception(self._retryable_error),
"after": retry_after_func,
"reraise": True,
}

@staticmethod
def _get_tenant_domain(conn: Connection) -> str:
Expand Down Expand Up @@ -233,6 +270,36 @@ async def get_headers_tenants_from_connection(self) -> tuple[dict[str, Any], str
headers["Authorization"] = f"Token {self.connection.password}"
return headers, tenant

def _log_request_error(self, attempt_num: int, error: str) -> None:
self.log.error("Attempt %s API Request to DBT failed with reason: %s", attempt_num, error)

@staticmethod
def _retryable_error(exception: BaseException) -> bool:
if isinstance(exception, requests_exceptions.RequestException):
if isinstance(exception, (requests_exceptions.ConnectionError, requests_exceptions.Timeout)) or (
exception.response is not None
and (exception.response.status_code >= 500 or exception.response.status_code == 429)
):
return True

if isinstance(exception, aiohttp.ClientResponseError):
if exception.status >= 500 or exception.status == 429:
return True

if isinstance(exception, (aiohttp.ClientConnectorError, TimeoutError)):
return True

return False

def _a_get_retry_object(self) -> AsyncRetrying:
"""
Instantiate an async retry object.

:return: instance of AsyncRetrying class
"""
# for compatibility we use reraise to avoid handling request error
return AsyncRetrying(**self.retry_args)

@provide_account_id
async def get_job_details(
self, run_id: int, account_id: int | None = None, include_related: list[str] | None = None
Expand All @@ -249,17 +316,22 @@ async def get_job_details(
headers, tenant = await self.get_headers_tenants_from_connection()
url, params = self.get_request_url_params(tenant, endpoint, include_related)
proxies = self._get_proxies(self.connection) or {}
proxy = proxies.get("https") if proxies and url.startswith("https") else proxies.get("http")
extra_request_args = {}

async with aiohttp.ClientSession(headers=headers) as session:
proxy = proxies.get("https") if proxies and url.startswith("https") else proxies.get("http")
extra_request_args = {}
if proxy:
extra_request_args["proxy"] = proxy

if proxy:
extra_request_args["proxy"] = proxy
timeout = (
aiohttp.ClientTimeout(total=self.timeout_seconds) if self.timeout_seconds is not None else None
)

async with session.get(url, params=params, **extra_request_args) as response: # type: ignore[arg-type]
response.raise_for_status()
return await response.json()
async with aiohttp.ClientSession(headers=headers, timeout=timeout) as session:
async for attempt in self._a_get_retry_object():
with attempt:
async with session.get(url, params=params, **extra_request_args) as response: # type: ignore[arg-type]
response.raise_for_status()
return await response.json()

async def get_job_status(
self, run_id: int, account_id: int | None = None, include_related: list[str] | None = None
Expand Down Expand Up @@ -297,8 +369,14 @@ def get_conn(self, *args, **kwargs) -> Session:
def _paginate(
self, endpoint: str, payload: dict[str, Any] | None = None, proxies: dict[str, str] | None = None
) -> list[Response]:
extra_options = {"proxies": proxies} if proxies is not None else None
response = self.run(endpoint=endpoint, data=payload, extra_options=extra_options)
extra_options: dict[str, Any] = {}
if self.timeout_seconds is not None:
extra_options["timeout"] = self.timeout_seconds
if proxies is not None:
extra_options["proxies"] = proxies
response = self.run_with_advanced_retry(
_retry_args=self.retry_args, endpoint=endpoint, data=payload, extra_options=extra_options or None
)
resp_json = response.json()
limit = resp_json["extra"]["filters"]["limit"]
num_total_results = resp_json["extra"]["pagination"]["total_count"]
Expand All @@ -309,7 +387,12 @@ def _paginate(
_paginate_payload["offset"] = limit

while num_current_results < num_total_results:
response = self.run(endpoint=endpoint, data=_paginate_payload, extra_options=extra_options)
response = self.run_with_advanced_retry(
_retry_args=self.retry_args,
endpoint=endpoint,
data=_paginate_payload,
extra_options=extra_options,
)
resp_json = response.json()
results.append(response)
num_current_results += resp_json["extra"]["pagination"]["count"]
Expand All @@ -328,7 +411,11 @@ def _run_and_get_response(
self.method = method
full_endpoint = f"api/{api_version}/accounts/{endpoint}" if endpoint else None
proxies = self._get_proxies(self.connection)
extra_options = {"proxies": proxies} if proxies is not None else None
extra_options: dict[str, Any] = {}
if self.timeout_seconds is not None:
extra_options["timeout"] = self.timeout_seconds
if proxies is not None:
extra_options["proxies"] = proxies

if paginate:
if isinstance(payload, str):
Expand All @@ -339,7 +426,12 @@ def _run_and_get_response(

raise ValueError("An endpoint is needed to paginate a response.")

return self.run(endpoint=full_endpoint, data=payload, extra_options=extra_options)
return self.run_with_advanced_retry(
_retry_args=self.retry_args,
endpoint=full_endpoint,
data=payload,
extra_options=extra_options or None,
)

def list_accounts(self) -> list[Response]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class DbtCloudRunJobOperator(BaseOperator):
run. For more information on retry logic, see:
https://docs.getdbt.com/dbt-cloud/api-v2#/operations/Retry%20Failed%20Job
:param deferrable: Run operator in the deferrable mode
:param hook_params: Extra arguments passed to the DbtCloudHook constructor.
:return: The ID of the triggered dbt Cloud job run.
"""

Expand Down Expand Up @@ -124,6 +125,7 @@ def __init__(
reuse_existing_run: bool = False,
retry_from_failure: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
hook_params: dict[str, Any] | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -144,6 +146,7 @@ def __init__(
self.reuse_existing_run = reuse_existing_run
self.retry_from_failure = retry_from_failure
self.deferrable = deferrable
self.hook_params = hook_params or {}

def execute(self, context: Context):
if self.trigger_reason is None:
Expand Down Expand Up @@ -273,7 +276,7 @@ def on_kill(self) -> None:
@cached_property
def hook(self):
"""Returns DBT Cloud hook."""
return DbtCloudHook(self.dbt_cloud_conn_id)
return DbtCloudHook(self.dbt_cloud_conn_id, **self.hook_params)

def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage:
"""
Expand Down Expand Up @@ -311,6 +314,7 @@ class DbtCloudGetJobRunArtifactOperator(BaseOperator):
be returned.
:param output_file_name: Optional. The desired file name for the download artifact file.
Defaults to <run_id>_<path> (e.g. "728368_run_results.json").
:param hook_params: Extra arguments passed to the DbtCloudHook constructor.
"""

template_fields = ("dbt_cloud_conn_id", "run_id", "path", "account_id", "output_file_name")
Expand All @@ -324,6 +328,7 @@ def __init__(
account_id: int | None = None,
step: int | None = None,
output_file_name: str | None = None,
hook_params: dict[str, Any] | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -333,9 +338,10 @@ def __init__(
self.account_id = account_id
self.step = step
self.output_file_name = output_file_name or f"{self.run_id}_{self.path}".replace("/", "-")
self.hook_params = hook_params or {}

def execute(self, context: Context) -> str:
hook = DbtCloudHook(self.dbt_cloud_conn_id)
hook = DbtCloudHook(self.dbt_cloud_conn_id, **self.hook_params)
response = hook.get_job_run_artifact(
run_id=self.run_id, path=self.path, account_id=self.account_id, step=self.step
)
Expand Down Expand Up @@ -370,6 +376,7 @@ class DbtCloudListJobsOperator(BaseOperator):
:param order_by: Optional. Field to order the result by. Use '-' to indicate reverse order.
For example, to use reverse order by the run ID use ``order_by=-id``.
:param project_id: Optional. The ID of a dbt Cloud project.
:param hook_params: Extra arguments passed to the DbtCloudHook constructor.
"""

template_fields = (
Expand All @@ -384,16 +391,18 @@ def __init__(
account_id: int | None = None,
project_id: int | None = None,
order_by: str | None = None,
hook_params: dict[str, Any] | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.dbt_cloud_conn_id = dbt_cloud_conn_id
self.account_id = account_id
self.project_id = project_id
self.order_by = order_by
self.hook_params = hook_params or {}

def execute(self, context: Context) -> list:
hook = DbtCloudHook(self.dbt_cloud_conn_id)
hook = DbtCloudHook(self.dbt_cloud_conn_id, **self.hook_params)
list_jobs_response = hook.list_jobs(
account_id=self.account_id, order_by=self.order_by, project_id=self.project_id
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class DbtCloudRunJobTrigger(BaseTrigger):
:param end_time: Time in seconds to wait for a job run to reach a terminal status. Defaults to 7 days.
:param account_id: The ID of a dbt Cloud account.
:param poll_interval: polling period in seconds to check for the status.
:param hook_params: Extra arguments passed to the DbtCloudHook constructor.
"""

def __init__(
Expand All @@ -45,13 +46,15 @@ def __init__(
end_time: float,
poll_interval: float,
account_id: int | None,
hook_params: dict[str, Any] | None = None,
):
super().__init__()
self.run_id = run_id
self.account_id = account_id
self.conn_id = conn_id
self.end_time = end_time
self.poll_interval = poll_interval
self.hook_params = hook_params or {}

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize DbtCloudRunJobTrigger arguments and classpath."""
Expand All @@ -63,12 +66,13 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"conn_id": self.conn_id,
"end_time": self.end_time,
"poll_interval": self.poll_interval,
"hook_params": self.hook_params,
},
)

async def run(self) -> AsyncIterator[TriggerEvent]:
"""Make async connection to Dbt, polls for the pipeline run status."""
hook = DbtCloudHook(self.conn_id)
hook = DbtCloudHook(self.conn_id, **self.hook_params)
try:
while await self.is_still_running(hook):
if self.end_time < time.time():
Expand Down
Loading
Loading