Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,16 @@ def serialize(self) -> tuple[str, dict[str, Any]]:

async def run(self) -> AsyncIterator[TriggerEvent]:
"""Wait for the query the snowflake query to complete."""
SnowflakeSqlApiHook(
hook = SnowflakeSqlApiHook(
self.snowflake_conn_id,
self.token_life_time,
self.token_renewal_delta,
)

try:
for query_id in self.query_ids:
while True:
statement_status = await self.get_query_status(query_id)
statement_status = await self.get_query_status(query_id, hook)
if statement_status["status"] not in ["running"]:
break
await asyncio.sleep(self.poll_interval)
Expand All @@ -92,13 +93,17 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
except Exception as e:
yield TriggerEvent({"status": "error", "message": str(e)})

async def get_query_status(self, query_id: str) -> dict[str, Any]:
async def get_query_status(
self, query_id: str, hook: SnowflakeSqlApiHook | None = None
) -> dict[str, Any]:
"""Return True if the SQL query is still running otherwise return False."""
hook = SnowflakeSqlApiHook(
self.snowflake_conn_id,
self.token_life_time,
self.token_renewal_delta,
)
if not hook:
hook = SnowflakeSqlApiHook(
self.snowflake_conn_id,
self.token_life_time,
self.token_renewal_delta,
)

return await hook.get_sql_api_query_status_async(query_id)

def _set_context(self, context):
Expand Down