diff --git a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/sensors/dbt.py b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/sensors/dbt.py index 0a7644a56a9c9..c2e9b9538b23f 100644 --- a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/sensors/dbt.py +++ b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/sensors/dbt.py @@ -55,6 +55,7 @@ def __init__( run_id: int, account_id: int | None = None, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + hook_params: dict[str, Any] | None = None, **kwargs, ) -> None: if deferrable: @@ -68,13 +69,13 @@ def __init__( self.dbt_cloud_conn_id = dbt_cloud_conn_id self.run_id = run_id self.account_id = account_id - + self.hook_params = hook_params or {} self.deferrable = deferrable @cached_property def hook(self): """Returns DBT Cloud hook.""" - return DbtCloudHook(self.dbt_cloud_conn_id) + return DbtCloudHook(self.dbt_cloud_conn_id, **self.hook_params) def poke(self, context: Context) -> bool: job_run_status = self.hook.get_job_run_status(run_id=self.run_id, account_id=self.account_id) @@ -110,6 +111,7 @@ def execute(self, context: Context) -> None: account_id=self.account_id, poll_interval=self.poke_interval, end_time=end_time, + hook_params=self.hook_params, ), method_name="execute_complete", ) diff --git a/providers/dbt/cloud/tests/unit/dbt/cloud/sensors/test_dbt.py b/providers/dbt/cloud/tests/unit/dbt/cloud/sensors/test_dbt.py index 13be664c78415..a33f846bb2a53 100644 --- a/providers/dbt/cloud/tests/unit/dbt/cloud/sensors/test_dbt.py +++ b/providers/dbt/cloud/tests/unit/dbt/cloud/sensors/test_dbt.py @@ -58,6 +58,7 @@ def setup_class(self): account_id=ACCOUNT_ID, timeout=30, poke_interval=15, + hook_params={"retry_limit": 3, "retry_delay": 2.0}, ) def test_init(self): @@ -65,6 +66,7 @@ def test_init(self): assert self.sensor.run_id == RUN_ID assert self.sensor.timeout == 30 assert self.sensor.poke_interval == 15 + assert self.sensor.hook_params == {"retry_limit": 3, "retry_delay": 2.0} @pytest.mark.parametrize( argnames=("job_run_status", "expected_poke_result"),