Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def execute_query(
When executing the statement, Snowflake replaces placeholders (? and :name) in
the statement with these specified values.
"""
self.query_ids = []
conn_config = self._get_conn_params

req_id = uuid.uuid4()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import time
from collections.abc import Iterable, Mapping, Sequence
from datetime import timedelta
from functools import cached_property
from typing import TYPE_CHECKING, Any, SupportsAbs, cast

from airflow.configuration import conf
Expand Down Expand Up @@ -390,6 +391,7 @@ def __init__(
self.bindings = bindings
self.execute_async = False
self.deferrable = deferrable
self.query_ids: list[str] = []
if any([warehouse, database, role, schema, authenticator, session_parameters]): # pragma: no cover
hook_params = kwargs.pop("hook_params", {}) # pragma: no cover
kwargs["hook_params"] = {
Expand All @@ -403,20 +405,23 @@ def __init__(
}
super().__init__(conn_id=snowflake_conn_id, **kwargs) # pragma: no cover

@cached_property
def _hook(self):
return SnowflakeSqlApiHook(
snowflake_conn_id=self.snowflake_conn_id,
token_life_time=self.token_life_time,
token_renewal_delta=self.token_renewal_delta,
deferrable=self.deferrable,
**self.hook_params,
)

def execute(self, context: Context) -> None:
"""
Make a POST API request to snowflake by using SnowflakeSQL and execute the query to get the ids.

By deferring the SnowflakeSqlApiTrigger class passed along with query ids.
"""
self.log.info("Executing: %s", self.sql)
self._hook = SnowflakeSqlApiHook(
snowflake_conn_id=self.snowflake_conn_id,
token_life_time=self.token_life_time,
token_renewal_delta=self.token_renewal_delta,
deferrable=self.deferrable,
**self.hook_params,
)
self.query_ids = self._hook.execute_query(
self.sql, # type: ignore[arg-type]
statement_count=self.statement_count,
Expand Down Expand Up @@ -504,9 +509,11 @@ def execute_complete(self, context: Context, event: dict[str, str | list[str]] |
msg = f"{event['status']}: {event['message']}"
raise AirflowException(msg)
if "status" in event and event["status"] == "success":
hook = SnowflakeSqlApiHook(snowflake_conn_id=self.snowflake_conn_id)
query_ids = cast("list[str]", event["statement_query_ids"])
hook.check_query_output(query_ids)
self.query_ids = cast("list[str]", event["statement_query_ids"])
self._hook.check_query_output(self.query_ids)
self.log.info("%s completed successfully.", self.task_id)
# Re-assign query_ids to hook after coming back from deferral to be consistent for listeners.
if not self._hook.query_ids:
self._hook.query_ids = self.query_ids
else:
self.log.info("%s completed successfully.", self.task_id)
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
self.token_renewal_delta,
)
try:
statement_query_ids: list[str] = []
for query_id in self.query_ids:
while True:
statement_status = await self.get_query_status(query_id)
Expand All @@ -84,12 +83,10 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
if statement_status["status"] == "error":
yield TriggerEvent(statement_status)
return
if statement_status["status"] == "success":
statement_query_ids.extend(statement_status["statement_handles"])
yield TriggerEvent(
{
"status": "success",
"statement_query_ids": statement_query_ids,
"statement_query_ids": self.query_ids,
}
)
except Exception as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,46 @@ def test_execute_query(
query_ids = hook.execute_query(sql, statement_count)
assert query_ids == expected_query_ids

@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests")
@mock.patch(
"airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params",
new_callable=PropertyMock,
)
@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_headers")
def test_execute_query_multiple_times_give_fresh_query_ids_each_time(
self, mock_get_header, mock_conn_param, mock_requests
):
"""Test execute_query method, run query by mocking post request method and return the query ids"""
sql, statement_count, expected_response, expected_query_ids = (
SQL_MULTIPLE_STMTS,
4,
{"statementHandles": ["uuid2", "uuid3"]},
["uuid2", "uuid3"],
)

mock_requests.codes.ok = 200
mock_requests.post.side_effect = [
create_successful_response_mock(expected_response),
]
status_code_mock = mock.PropertyMock(return_value=200)
type(mock_requests.post.return_value).status_code = status_code_mock

hook = SnowflakeSqlApiHook("mock_conn_id")
query_ids = hook.execute_query(sql, statement_count)
assert query_ids == expected_query_ids

sql, statement_count, expected_response, expected_query_ids = (
SINGLE_STMT,
1,
{"statementHandle": "uuid"},
["uuid"],
)
mock_requests.post.side_effect = [
create_successful_response_mock(expected_response),
]
query_ids = hook.execute_query(sql, statement_count)
assert query_ids == expected_query_ids

@pytest.mark.parametrize(
"sql,statement_count,expected_response, expected_query_ids",
[(SINGLE_STMT, 1, {"statementHandle": "uuid"}, ["uuid"])],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,48 @@ def test_snowflake_sql_api_execute_complete(self, mock_conn, mock_event):
operator.execute_complete(context=None, event=mock_event)
mock_log_info.assert_called_with("%s completed successfully.", TASK_ID)

@pytest.mark.parametrize(
"mock_event",
[
None,
({"status": "success", "statement_query_ids": ["uuid", "uuid"]}),
],
)
@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.check_query_output")
def test_snowflake_sql_api_execute_complete_reassigns_query_ids(self, mock_conn, mock_event):
"""Tests execute_complete assert with successful message"""

operator = SnowflakeSqlApiOperator(
task_id=TASK_ID,
snowflake_conn_id=CONN_ID,
sql=SQL_MULTIPLE_STMTS,
statement_count=4,
deferrable=True,
)
expected_query_ids = mock_event["statement_query_ids"] if mock_event else []

assert operator.query_ids == []
assert operator._hook.query_ids == []

operator.execute_complete(context=None, event=mock_event)

assert operator.query_ids == expected_query_ids
assert operator._hook.query_ids == expected_query_ids

def test_snowflake_sql_api_caches_hook(self):
"""Tests execute_complete assert with successful message"""

operator = SnowflakeSqlApiOperator(
task_id=TASK_ID,
snowflake_conn_id=CONN_ID,
sql=SQL_MULTIPLE_STMTS,
statement_count=4,
deferrable=True,
)
hook1 = operator._hook
hook2 = operator._hook
assert hook1 is hook2

@mock.patch("airflow.providers.snowflake.operators.snowflake.SnowflakeSqlApiOperator.defer")
def test_snowflake_sql_api_execute_operator_failed_before_defer(
self, mock_defer, mock_execute_query, mock_get_sql_api_query_status
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@
LIFETIME = timedelta(minutes=59)
RENEWAL_DELTA = timedelta(minutes=54)
MODULE = "airflow.providers.snowflake"
QUERY_IDS = ["uuid"]


class TestSnowflakeSqlApiTrigger:
TRIGGER = SnowflakeSqlApiTrigger(
poll_interval=POLL_INTERVAL,
query_ids=["uuid"],
query_ids=QUERY_IDS,
snowflake_conn_id="test_conn",
token_life_time=LIFETIME,
token_renewal_delta=RENEWAL_DELTA,
Expand Down Expand Up @@ -82,8 +83,8 @@ async def test_snowflake_sql_trigger_completed(
Test SnowflakeSqlApiTrigger run method with success status and mock the get_sql_api_query_status
result and get_query_status to False.
"""
mock_get_query_status.return_value = {"status": "success", "statement_handles": ["uuid", "uuid1"]}
statement_query_ids = ["uuid", "uuid1"]
mock_get_query_status.return_value = {"status": "success", "statement_handles": statement_query_ids}
mock_get_sql_api_query_status_async.return_value = {
"message": "Statement executed successfully.",
"status": "success",
Expand All @@ -92,7 +93,7 @@ async def test_snowflake_sql_trigger_completed(

generator = self.TRIGGER.run()
actual = await generator.asend(None)
assert TriggerEvent({"status": "success", "statement_query_ids": statement_query_ids}) == actual
assert TriggerEvent({"status": "success", "statement_query_ids": QUERY_IDS}) == actual

@pytest.mark.asyncio
@mock.patch(f"{MODULE}.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_sql_api_query_status_async")
Expand Down