diff --git a/providers/snowflake/src/airflow/providers/snowflake/triggers/snowflake_trigger.py b/providers/snowflake/src/airflow/providers/snowflake/triggers/snowflake_trigger.py index a7aa8f3ca82d4..e460ea7840ce4 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/triggers/snowflake_trigger.py +++ b/providers/snowflake/src/airflow/providers/snowflake/triggers/snowflake_trigger.py @@ -68,15 +68,16 @@ def serialize(self) -> tuple[str, dict[str, Any]]: async def run(self) -> AsyncIterator[TriggerEvent]: """Wait for the query the snowflake query to complete.""" - SnowflakeSqlApiHook( + hook = SnowflakeSqlApiHook( self.snowflake_conn_id, self.token_life_time, self.token_renewal_delta, ) + try: for query_id in self.query_ids: while True: - statement_status = await self.get_query_status(query_id) + statement_status = await self.get_query_status(query_id, hook) if statement_status["status"] not in ["running"]: break await asyncio.sleep(self.poll_interval) @@ -92,13 +93,17 @@ async def run(self) -> AsyncIterator[TriggerEvent]: except Exception as e: yield TriggerEvent({"status": "error", "message": str(e)}) - async def get_query_status(self, query_id: str) -> dict[str, Any]: + async def get_query_status( + self, query_id: str, hook: SnowflakeSqlApiHook | None = None + ) -> dict[str, Any]: """Return True if the SQL query is still running otherwise return False.""" - hook = SnowflakeSqlApiHook( - self.snowflake_conn_id, - self.token_life_time, - self.token_renewal_delta, - ) + if not hook: + hook = SnowflakeSqlApiHook( + self.snowflake_conn_id, + self.token_life_time, + self.token_renewal_delta, + ) + return await hook.get_sql_api_query_status_async(query_id) def _set_context(self, context):