Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -455,16 +455,21 @@ def execute(self, context: Context) -> None:
method_name="execute_complete",
)
else:
statement_status = self.poll_on_queries()
if statement_status["error"]:
raise AirflowException(statement_status["error"])
while True:
statement_status = self.poll_on_queries()
if statement_status["error"]:
raise AirflowException(statement_status["error"])
if not statement_status["running"]:
break

self._hook.check_query_output(self.query_ids)

def poll_on_queries(self):
"""Poll on requested queries."""
queries_in_progress = set(self.query_ids)
statement_success_status = {}
statement_error_status = {}
statement_running_status = {}
for query_id in self.query_ids:
if not len(queries_in_progress):
break
Expand All @@ -479,8 +484,14 @@ def poll_on_queries(self):
if statement_status.get("status") == "success":
statement_success_status[query_id] = statement_status
queries_in_progress.remove(query_id)
if statement_status.get("status") == "running":
statement_running_status[query_id] = statement_status
time.sleep(self.poll_interval)
return {"success": statement_success_status, "error": statement_error_status}
return {
"success": statement_success_status,
"error": statement_error_status,
"running": statement_running_status,
}

def execute_complete(self, context: Context, event: dict[str, str | list[str]] | None = None) -> None:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -391,3 +391,66 @@ def test_snowflake_sql_api_execute_operator_running_before_defer(
operator.execute(create_context(operator))

assert mock_defer.called

def test_snowflake_sql_api_execute_operator_polling_running(
self, mock_execute_query, mock_get_sql_api_query_status, mock_check_query_output
):
"""
Tests that the execute method correctly loops and waits until all queries complete
when ``deferrable=False``
"""
operator = SnowflakeSqlApiOperator(
task_id=TASK_ID,
snowflake_conn_id=CONN_ID,
sql=SQL_MULTIPLE_STMTS,
statement_count=4,
do_xcom_push=False,
deferrable=False,
)

mock_execute_query.return_value = ["uuid1"]

mock_get_sql_api_query_status.side_effect = [
# Initial get_sql_api_query_status check
{"status": "running"},
# 1st poll_on_queries check (poll_interval: 5s)
{"status": "running"},
# 2nd poll_on_queries check (poll_interval: 5s)
{"status": "running"},
# 3rd poll_on_queries check (poll_interval: 5s)
{"status": "success"},
]

with mock.patch("time.sleep") as mock_sleep:
operator.execute(context=None)
mock_check_query_output.assert_called_once_with(["uuid1"])
assert mock_sleep.call_count == 3

def test_snowflake_sql_api_execute_operator_polling_failed(
self, mock_execute_query, mock_get_sql_api_query_status, mock_check_query_output
):
"""
Tests that the execute method raises AirflowException if any query fails during polling
when ``deferrable=False``
"""
operator = SnowflakeSqlApiOperator(
task_id=TASK_ID,
snowflake_conn_id=CONN_ID,
sql=SQL_MULTIPLE_STMTS,
statement_count=4,
do_xcom_push=False,
deferrable=False,
)

mock_execute_query.return_value = ["uuid1"]

mock_get_sql_api_query_status.side_effect = [
# Initial get_sql_api_query_status check
{"status": "running"},
# 1st poll_on_queries check
{"status": "error"},
]

with pytest.raises(AirflowException):
operator.execute(context=None)
mock_check_query_output.assert_not_called()