@@ -17,6 +17,7 @@
from __future__ import annotations

import base64
import time
import uuid
import warnings
from datetime import timedelta
@@ -313,6 +314,80 @@ def get_sql_api_query_status(self, query_id: str) -> dict[str, str | list[str]]:
status_code, resp = self._make_api_call_with_retries("GET", url, header, params)
return self._process_response(status_code, resp)

def wait_for_query(
self, query_id: str, raise_error: bool = False, poll_interval: int = 5, timeout: int = 60
) -> dict[str, str | list[str]]:
"""
Wait for the query to finish, either successfully or with an error.

:param query_id: statement handle id for the individual statement.
:param raise_error: whether to raise an error if the query failed.
:param poll_interval: time (in seconds) between checking the query status.
:param timeout: max time (in seconds) to wait for the query to finish before raising a TimeoutError.

:raises RuntimeError: If the query status is 'error' and `raise_error` is True.
:raises TimeoutError: If the query doesn't finish within the specified timeout.
"""
start_time = time.time()

while True:
response = self.get_sql_api_query_status(query_id=query_id)
self.log.debug("Query status `%s`", response["status"])

if time.time() - start_time > timeout:
raise TimeoutError(
f"Query `{query_id}` did not finish within the timeout period of {timeout} seconds."
)

if response["status"] != "running":
self.log.info("Query status `%s`", response["status"])
break

time.sleep(poll_interval)

if response["status"] == "error" and raise_error:
raise RuntimeError(response["message"])

return response
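A minimal usage sketch (not part of this diff) for the new polling helper, assuming the method lives on `SnowflakeSqlApiHook`; the connection id and SQL below are placeholders:

from airflow.providers.snowflake.hooks.snowflake_sql_api import SnowflakeSqlApiHook

hook = SnowflakeSqlApiHook(snowflake_conn_id="snowflake_default")  # placeholder connection id
query_ids = hook.execute_query("SELECT 1", statement_count=1)  # returns statement handle ids

# Poll every 5 seconds, give up after 10 minutes, and raise if the query fails.
response = hook.wait_for_query(query_ids[0], raise_error=True, poll_interval=5, timeout=600)
print(response["status"])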

def get_result_from_successful_sql_api_query(self, query_id: str) -> list[dict[str, Any]]:
"""
Make HTTP requests to the Snowflake SQL API based on the query id and return the result data.

:param query_id: statement handle id for the individual statement.

:raises RuntimeError: If the query status is not 'success'.
"""
self.log.info("Retrieving data for query id %s", query_id)
header, params, url = self.get_request_url_header_params(query_id)
status_code, response = self._make_api_call_with_retries("GET", url, header, params)

if (query_status := self._process_response(status_code, response)["status"]) != "success":
msg = f"Query must have status `success` to retrieve data; got `{query_status}`."
raise RuntimeError(msg)

# The fields below should always be present in the response, but add some safety checks just in case
data = response.get("data", [])
if not data:
self.log.warning("No data found in the API response.")
return []
metadata = response.get("resultSetMetaData", {})
col_names = [row["name"] for row in metadata.get("rowType", [])]
if not col_names:
self.log.warning("No column metadata found in the API response.")
return []

num_partitions = len(metadata.get("partitionInfo", []))
if num_partitions > 1:
self.log.debug("Result data is returned as multiple partitions. Will perform additional queries.")
url += "?partition="
for partition_no in range(1, num_partitions): # First partition was already returned
self.log.debug("Querying for partition no. %s", partition_no)
_, response = self._make_api_call_with_retries("GET", url + str(partition_no), header, params)
data.extend(response.get("data", []))

return [dict(zip(col_names, row)) for row in data]  # Merge column names with data
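A follow-on sketch, continuing the hypothetical example above, of how the result helper might be used once the statement has finished; partitioned result sets are stitched together inside the method, so the caller only sees a flat list of row dicts:

response = hook.wait_for_query(query_ids[0], raise_error=True)
if response["status"] == "success":
    rows = hook.get_result_from_successful_sql_api_query(query_ids[0])
    for row in rows:
        print(row)  # each row is a dict keyed by column name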

async def get_sql_api_query_status_async(self, query_id: str) -> dict[str, str | list[str]]:
"""
Based on the query id async HTTP request is made to snowflake SQL API and return response.