Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,18 @@

import aiohttp
import requests
from aiohttp import ClientConnectionError, ClientResponseError
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from requests.exceptions import ConnectionError, HTTPError, Timeout
from tenacity import (
AsyncRetrying,
Retrying,
before_sleep_log,
retry_if_exception,
stop_after_attempt,
wait_exponential,
)

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
Expand Down Expand Up @@ -65,6 +75,7 @@ class SnowflakeSqlApiHook(SnowflakeHook):
:param token_life_time: lifetime of the JWT Token in timedelta
:param token_renewal_delta: Renewal time of the JWT Token in timedelta
:param deferrable: Run operator in the deferrable mode.
:param api_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` & ``tenacity.AsyncRetrying`` classes.
"""

LIFETIME = timedelta(minutes=59) # The tokens will have a 59 minute lifetime
Expand All @@ -75,15 +86,27 @@ def __init__(
snowflake_conn_id: str,
token_life_time: timedelta = LIFETIME,
token_renewal_delta: timedelta = RENEWAL_DELTA,
api_retry_args: dict[Any, Any] | None = None, # Optional retry arguments passed to tenacity.retry
*args: Any,
**kwargs: Any,
):
self.snowflake_conn_id = snowflake_conn_id
self.token_life_time = token_life_time
self.token_renewal_delta = token_renewal_delta

super().__init__(snowflake_conn_id, *args, **kwargs)
self.private_key: Any = None

self.retry_config = {
"retry": retry_if_exception(self._should_retry_on_error),
"wait": wait_exponential(multiplier=1, min=1, max=60),
"stop": stop_after_attempt(5),
"before_sleep": before_sleep_log(self.log, log_level=20), # INFO level
"reraise": True,
}
if api_retry_args:
self.retry_config.update(api_retry_args)

def get_private_key(self) -> None:
"""Get the private key from snowflake connection."""
conn = self.get_connection(self.snowflake_conn_id)
Expand Down Expand Up @@ -168,13 +191,8 @@ def execute_query(
"query_tag": query_tag,
},
}
response = requests.post(url, json=data, headers=headers, params=params)
try:
response.raise_for_status()
except requests.exceptions.HTTPError as e: # pragma: no cover
msg = f"Response: {e.response.content.decode()} Status Code: {e.response.status_code}"
raise AirflowException(msg)
json_response = response.json()

_, json_response = self._make_api_call_with_retries("POST", url, headers, params, data)
self.log.info("Snowflake SQL POST API response: %s", json_response)
if "statementHandles" in json_response:
self.query_ids = json_response["statementHandles"]
Expand Down Expand Up @@ -259,13 +277,10 @@ def check_query_output(self, query_ids: list[str]) -> None:
"""
for query_id in query_ids:
header, params, url = self.get_request_url_header_params(query_id)
try:
response = requests.get(url, headers=header, params=params)
response.raise_for_status()
self.log.info(response.json())
except requests.exceptions.HTTPError as e:
msg = f"Response: {e.response.content.decode()}, Status Code: {e.response.status_code}"
raise AirflowException(msg)
_, response_json = self._make_api_call_with_retries(
method="GET", url=url, headers=header, params=params
)
self.log.info(response_json)

def _process_response(self, status_code, resp):
self.log.info("Snowflake SQL GET statements status API response: %s", resp)
Expand Down Expand Up @@ -295,9 +310,7 @@ def get_sql_api_query_status(self, query_id: str) -> dict[str, str | list[str]]:
"""
self.log.info("Retrieving status for query id %s", query_id)
header, params, url = self.get_request_url_header_params(query_id)
response = requests.get(url, params=params, headers=header)
status_code = response.status_code
resp = response.json()
status_code, resp = self._make_api_call_with_retries("GET", url, header, params)
return self._process_response(status_code, resp)

async def get_sql_api_query_status_async(self, query_id: str) -> dict[str, str | list[str]]:
Expand All @@ -308,10 +321,85 @@ async def get_sql_api_query_status_async(self, query_id: str) -> dict[str, str |
"""
self.log.info("Retrieving status for query id %s", query_id)
header, params, url = self.get_request_url_header_params(query_id)
async with (
aiohttp.ClientSession(headers=header) as session,
session.get(url, params=params) as response,
status_code, resp = await self._make_api_call_with_retries_async("GET", url, header, params)
return self._process_response(status_code, resp)

@staticmethod
def _should_retry_on_error(exception) -> bool:
"""
Determine if the exception should trigger a retry based on error type and status code.

Retries on HTTP errors 429 (Too Many Requests), 503 (Service Unavailable),
and 504 (Gateway Timeout) as recommended by Snowflake error handling docs.
Retries on connection errors and timeouts.

:param exception: The exception to check
:return: True if the request should be retried, False otherwise
"""
if isinstance(exception, HTTPError):
return exception.response.status_code in [429, 503, 504]
if isinstance(exception, ClientResponseError):
return exception.status in [429, 503, 504]
if isinstance(
exception,
(
ConnectionError,
Timeout,
ClientConnectionError,
),
):
status_code = response.status
resp = await response.json()
return self._process_response(status_code, resp)
return True
return False

def _make_api_call_with_retries(
self, method: str, url: str, headers: dict, params: dict | None = None, json: dict | None = None
):
"""
Make an API call to the Snowflake SQL API with retry logic for specific HTTP errors.

Error handling implemented based on Snowflake error handling docs:
https://docs.snowflake.com/en/developer-guide/sql-api/handling-errors

:param method: The HTTP method to use for the API call.
:param url: The URL for the API endpoint.
:param headers: The headers to include in the API call.
:param params: (Optional) The query parameters to include in the API call.
:param data: (Optional) The data to include in the API call.
:return: The response object from the API call.
"""
with requests.Session() as session:
for attempt in Retrying(**self.retry_config): # type: ignore
with attempt:
if method.upper() in ("GET", "POST"):
response = session.request(
method=method.lower(), url=url, headers=headers, params=params, json=json
)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
response.raise_for_status()
return response.status_code, response.json()

async def _make_api_call_with_retries_async(self, method, url, headers, params=None):
"""
Make an API call to the Snowflake SQL API asynchronously with retry logic for specific HTTP errors.

Error handling implemented based on Snowflake error handling docs:
https://docs.snowflake.com/en/developer-guide/sql-api/handling-errors

:param method: The HTTP method to use for the API call. Only GET is supported as is synchronous.
:param url: The URL for the API endpoint.
:param headers: The headers to include in the API call.
:param params: (Optional) The query parameters to include in the API call.
:return: The response object from the API call.
"""
async with aiohttp.ClientSession(headers=headers) as session:
async for attempt in AsyncRetrying(**self.retry_config): # type: ignore
with attempt:
if method.upper() == "GET":
async with session.request(method=method.lower(), url=url, params=params) as response:
response.raise_for_status()
# Return status and json content for async processing
content = await response.json()
return response.status, content
else:
raise ValueError(f"Unsupported HTTP method: {method}")
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
When executing the statement, Snowflake replaces placeholders (? and :name) in
the statement with these specified values.
:param deferrable: Run operator in the deferrable mode.
:param snowflake_api_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` & ``tenacity.AsyncRetrying`` classes.
"""

LIFETIME = timedelta(minutes=59) # The tokens will have a 59 minutes lifetime
Expand All @@ -381,6 +382,7 @@ def __init__(
token_renewal_delta: timedelta = RENEWAL_DELTA,
bindings: dict[str, Any] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
snowflake_api_retry_args: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
self.snowflake_conn_id = snowflake_conn_id
Expand All @@ -390,6 +392,7 @@ def __init__(
self.token_renewal_delta = token_renewal_delta
self.bindings = bindings
self.execute_async = False
self.snowflake_api_retry_args = snowflake_api_retry_args or {}
self.deferrable = deferrable
self.query_ids: list[str] = []
if any([warehouse, database, role, schema, authenticator, session_parameters]): # pragma: no cover
Expand All @@ -412,6 +415,7 @@ def _hook(self):
token_life_time=self.token_life_time,
token_renewal_delta=self.token_renewal_delta,
deferrable=self.deferrable,
api_retry_args=self.snowflake_api_retry_args,
**self.hook_params,
)

Expand Down
Loading