From be17babcf092c6298ce524d1985446222ece8935 Mon Sep 17 00:00:00 2001 From: Kacper Muda Date: Mon, 9 Jun 2025 18:39:07 +0200 Subject: [PATCH] fix: make query_ids in SnowflakeSqlApiOperator in deferrable mode consistent --- .../snowflake/hooks/snowflake_sql_api.py | 1 + .../snowflake/operators/snowflake.py | 27 +++++++----- .../snowflake/triggers/snowflake_trigger.py | 5 +-- .../snowflake/hooks/test_snowflake_sql_api.py | 40 ++++++++++++++++++ .../snowflake/operators/test_snowflake.py | 42 +++++++++++++++++++ .../unit/snowflake/triggers/test_snowflake.py | 7 ++-- 6 files changed, 105 insertions(+), 17 deletions(-) diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py index c7b9765c60968..0a2c9fd424776 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py @@ -137,6 +137,7 @@ def execute_query( When executing the statement, Snowflake replaces placeholders (? and :name) in the statement with these specified values. """ + self.query_ids = [] conn_config = self._get_conn_params req_id = uuid.uuid4() diff --git a/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py index 1018871654030..2c8db1391e82b 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py +++ b/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py @@ -20,6 +20,7 @@ import time from collections.abc import Iterable, Mapping, Sequence from datetime import timedelta +from functools import cached_property from typing import TYPE_CHECKING, Any, SupportsAbs, cast from airflow.configuration import conf @@ -390,6 +391,7 @@ def __init__( self.bindings = bindings self.execute_async = False self.deferrable = deferrable + self.query_ids: list[str] = [] if any([warehouse, database, role, schema, authenticator, session_parameters]): # pragma: no cover hook_params = kwargs.pop("hook_params", {}) # pragma: no cover kwargs["hook_params"] = { @@ -403,6 +405,16 @@ def __init__( } super().__init__(conn_id=snowflake_conn_id, **kwargs) # pragma: no cover + @cached_property + def _hook(self): + return SnowflakeSqlApiHook( + snowflake_conn_id=self.snowflake_conn_id, + token_life_time=self.token_life_time, + token_renewal_delta=self.token_renewal_delta, + deferrable=self.deferrable, + **self.hook_params, + ) + def execute(self, context: Context) -> None: """ Make a POST API request to snowflake by using SnowflakeSQL and execute the query to get the ids. @@ -410,13 +422,6 @@ def execute(self, context: Context) -> None: By deferring the SnowflakeSqlApiTrigger class passed along with query ids. """ self.log.info("Executing: %s", self.sql) - self._hook = SnowflakeSqlApiHook( - snowflake_conn_id=self.snowflake_conn_id, - token_life_time=self.token_life_time, - token_renewal_delta=self.token_renewal_delta, - deferrable=self.deferrable, - **self.hook_params, - ) self.query_ids = self._hook.execute_query( self.sql, # type: ignore[arg-type] statement_count=self.statement_count, @@ -504,9 +509,11 @@ def execute_complete(self, context: Context, event: dict[str, str | list[str]] | msg = f"{event['status']}: {event['message']}" raise AirflowException(msg) if "status" in event and event["status"] == "success": - hook = SnowflakeSqlApiHook(snowflake_conn_id=self.snowflake_conn_id) - query_ids = cast("list[str]", event["statement_query_ids"]) - hook.check_query_output(query_ids) + self.query_ids = cast("list[str]", event["statement_query_ids"]) + self._hook.check_query_output(self.query_ids) self.log.info("%s completed successfully.", self.task_id) + # Re-assign query_ids to hook after coming back from deferral to be consistent for listeners. + if not self._hook.query_ids: + self._hook.query_ids = self.query_ids else: self.log.info("%s completed successfully.", self.task_id) diff --git a/providers/snowflake/src/airflow/providers/snowflake/triggers/snowflake_trigger.py b/providers/snowflake/src/airflow/providers/snowflake/triggers/snowflake_trigger.py index b425b9a6250d3..a7aa8f3ca82d4 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/triggers/snowflake_trigger.py +++ b/providers/snowflake/src/airflow/providers/snowflake/triggers/snowflake_trigger.py @@ -74,7 +74,6 @@ async def run(self) -> AsyncIterator[TriggerEvent]: self.token_renewal_delta, ) try: - statement_query_ids: list[str] = [] for query_id in self.query_ids: while True: statement_status = await self.get_query_status(query_id) @@ -84,12 +83,10 @@ async def run(self) -> AsyncIterator[TriggerEvent]: if statement_status["status"] == "error": yield TriggerEvent(statement_status) return - if statement_status["status"] == "success": - statement_query_ids.extend(statement_status["statement_handles"]) yield TriggerEvent( { "status": "success", - "statement_query_ids": statement_query_ids, + "statement_query_ids": self.query_ids, } ) except Exception as e: diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py index 91fd14777bb5d..21a3fa7a999a8 100644 --- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py +++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py @@ -203,6 +203,46 @@ def test_execute_query( query_ids = hook.execute_query(sql, statement_count) assert query_ids == expected_query_ids + @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests") + @mock.patch( + "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params", + new_callable=PropertyMock, + ) + @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_headers") + def test_execute_query_multiple_times_give_fresh_query_ids_each_time( + self, mock_get_header, mock_conn_param, mock_requests + ): + """Test execute_query method, run query by mocking post request method and return the query ids""" + sql, statement_count, expected_response, expected_query_ids = ( + SQL_MULTIPLE_STMTS, + 4, + {"statementHandles": ["uuid2", "uuid3"]}, + ["uuid2", "uuid3"], + ) + + mock_requests.codes.ok = 200 + mock_requests.post.side_effect = [ + create_successful_response_mock(expected_response), + ] + status_code_mock = mock.PropertyMock(return_value=200) + type(mock_requests.post.return_value).status_code = status_code_mock + + hook = SnowflakeSqlApiHook("mock_conn_id") + query_ids = hook.execute_query(sql, statement_count) + assert query_ids == expected_query_ids + + sql, statement_count, expected_response, expected_query_ids = ( + SINGLE_STMT, + 1, + {"statementHandle": "uuid"}, + ["uuid"], + ) + mock_requests.post.side_effect = [ + create_successful_response_mock(expected_response), + ] + query_ids = hook.execute_query(sql, statement_count) + assert query_ids == expected_query_ids + @pytest.mark.parametrize( "sql,statement_count,expected_response, expected_query_ids", [(SINGLE_STMT, 1, {"statementHandle": "uuid"}, ["uuid"])], diff --git a/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py b/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py index 1651818ed54bf..b61b699774e78 100644 --- a/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py +++ b/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py @@ -332,6 +332,48 @@ def test_snowflake_sql_api_execute_complete(self, mock_conn, mock_event): operator.execute_complete(context=None, event=mock_event) mock_log_info.assert_called_with("%s completed successfully.", TASK_ID) + @pytest.mark.parametrize( + "mock_event", + [ + None, + ({"status": "success", "statement_query_ids": ["uuid", "uuid"]}), + ], + ) + @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.check_query_output") + def test_snowflake_sql_api_execute_complete_reassigns_query_ids(self, mock_conn, mock_event): + """Tests execute_complete assert with successful message""" + + operator = SnowflakeSqlApiOperator( + task_id=TASK_ID, + snowflake_conn_id=CONN_ID, + sql=SQL_MULTIPLE_STMTS, + statement_count=4, + deferrable=True, + ) + expected_query_ids = mock_event["statement_query_ids"] if mock_event else [] + + assert operator.query_ids == [] + assert operator._hook.query_ids == [] + + operator.execute_complete(context=None, event=mock_event) + + assert operator.query_ids == expected_query_ids + assert operator._hook.query_ids == expected_query_ids + + def test_snowflake_sql_api_caches_hook(self): + """Tests execute_complete assert with successful message""" + + operator = SnowflakeSqlApiOperator( + task_id=TASK_ID, + snowflake_conn_id=CONN_ID, + sql=SQL_MULTIPLE_STMTS, + statement_count=4, + deferrable=True, + ) + hook1 = operator._hook + hook2 = operator._hook + assert hook1 is hook2 + @mock.patch("airflow.providers.snowflake.operators.snowflake.SnowflakeSqlApiOperator.defer") def test_snowflake_sql_api_execute_operator_failed_before_defer( self, mock_defer, mock_execute_query, mock_get_sql_api_query_status diff --git a/providers/snowflake/tests/unit/snowflake/triggers/test_snowflake.py b/providers/snowflake/tests/unit/snowflake/triggers/test_snowflake.py index 9fc14591625e2..42a3a07d224de 100644 --- a/providers/snowflake/tests/unit/snowflake/triggers/test_snowflake.py +++ b/providers/snowflake/tests/unit/snowflake/triggers/test_snowflake.py @@ -30,12 +30,13 @@ LIFETIME = timedelta(minutes=59) RENEWAL_DELTA = timedelta(minutes=54) MODULE = "airflow.providers.snowflake" +QUERY_IDS = ["uuid"] class TestSnowflakeSqlApiTrigger: TRIGGER = SnowflakeSqlApiTrigger( poll_interval=POLL_INTERVAL, - query_ids=["uuid"], + query_ids=QUERY_IDS, snowflake_conn_id="test_conn", token_life_time=LIFETIME, token_renewal_delta=RENEWAL_DELTA, @@ -82,8 +83,8 @@ async def test_snowflake_sql_trigger_completed( Test SnowflakeSqlApiTrigger run method with success status and mock the get_sql_api_query_status result and get_query_status to False. """ - mock_get_query_status.return_value = {"status": "success", "statement_handles": ["uuid", "uuid1"]} statement_query_ids = ["uuid", "uuid1"] + mock_get_query_status.return_value = {"status": "success", "statement_handles": statement_query_ids} mock_get_sql_api_query_status_async.return_value = { "message": "Statement executed successfully.", "status": "success", @@ -92,7 +93,7 @@ async def test_snowflake_sql_trigger_completed( generator = self.TRIGGER.run() actual = await generator.asend(None) - assert TriggerEvent({"status": "success", "statement_query_ids": statement_query_ids}) == actual + assert TriggerEvent({"status": "success", "statement_query_ids": QUERY_IDS}) == actual @pytest.mark.asyncio @mock.patch(f"{MODULE}.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_sql_api_query_status_async")