Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import os
from collections.abc import Callable, Iterable, Mapping
from contextlib import closing, contextmanager
from datetime import datetime, timedelta
from functools import cached_property
from io import StringIO
from pathlib import Path
Expand All @@ -41,9 +42,14 @@
from airflow.providers.common.sql.hooks.handlers import return_single_query_results
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.snowflake.utils.openlineage import fix_snowflake_sqlalchemy_uri
from airflow.utils import timezone
from airflow.utils.strings import to_boolean

OAUTH_REQUEST_TIMEOUT = 30 # seconds, avoid hanging tasks on token request
OAUTH_EXPIRY_BUFFER = 30
T = TypeVar("T")


if TYPE_CHECKING:
from airflow.providers.openlineage.extractors import OperatorLineage
from airflow.providers.openlineage.sqlparser import DatabaseInfo
Expand Down Expand Up @@ -173,6 +179,11 @@ def __init__(self, *args, **kwargs) -> None:
self.client_store_temporary_credential = kwargs.pop("client_store_temporary_credential", None)
self.query_ids: list[str] = []

# Cached OAuth access token and its expiry timestamp, kept on the hook so
# the token is reused while still valid and re-requested once it has expired.
self._oauth_token: str | None = None
self._oauth_token_expires_at: datetime | None = None

def _get_field(self, extra_dict, field_name):
backcompat_prefix = "extra__snowflake__"
backcompat_key = f"{backcompat_prefix}{field_name}"
Expand All @@ -198,7 +209,7 @@ def _get_field(self, extra_dict, field_name):
@property
def account_identifier(self) -> str:
"""Get snowflake account identifier."""
conn_config = self._get_conn_params
conn_config = self._get_conn_params()
account_identifier = f"https://{conn_config['account']}"

if conn_config["region"]:
Expand All @@ -214,46 +225,15 @@ def get_oauth_token(
) -> str:
"""Generate temporary OAuth access token using refresh token in connection details."""
if conn_config is None:
conn_config = self._get_conn_params
conn_config = self._get_static_conn_params

url = token_endpoint or f"https://{conn_config['account']}.snowflakecomputing.com/oauth/token-request"
if token_endpoint is None:
token_endpoint = conn_config.get("token_endpoint")

data = {
"grant_type": grant_type,
"redirect_uri": conn_config.get("redirect_uri", "https://localhost.com"),
}

scope = conn_config.get("scope")

if scope:
data["scope"] = scope

if grant_type == "refresh_token":
data |= {
"refresh_token": conn_config["refresh_token"],
}
elif grant_type == "client_credentials":
pass # no setup necessary for client credentials grant.
else:
raise ValueError(f"Unknown grant_type: {grant_type}")

response = requests.post(
url,
data=data,
headers={
"Content-Type": "application/x-www-form-urlencoded",
},
auth=HTTPBasicAuth(conn_config["client_id"], conn_config["client_secret"]), # type: ignore[arg-type]
return self._get_valid_oauth_token(
conn_config=conn_config, token_endpoint=token_endpoint, grant_type=grant_type
)

try:
response.raise_for_status()
except requests.exceptions.HTTPError as e: # pragma: no cover
msg = f"Response: {e.response.content.decode()} Status Code: {e.response.status_code}"
raise AirflowException(msg)
token = response.json()["access_token"]
return token

def get_azure_oauth_token(self, azure_conn_id: str) -> str:
"""
Generate OAuth access token using Azure connection id.
Expand Down Expand Up @@ -286,12 +266,42 @@ def get_azure_oauth_token(self, azure_conn_id: str) -> str:
token = azure_base_hook.get_token(scope).token
return token

@cached_property
def _get_conn_params(self) -> dict[str, str | None]:
"""
Fetch connection params as a dict.

This is used in ``get_uri()`` and ``get_connection()``.
This is used in ``get_uri()`` and ``get_conn()``.
"""
static_config = self._get_static_conn_params
conn_config = dict(static_config)

if conn_config.get("authenticator") == "oauth":
azure_conn_id = conn_config.get("azure_conn_id")
if azure_conn_id:
conn_config["token"] = self.get_azure_oauth_token(azure_conn_id)
else:
grant_type = conn_config.get("grant_type")
if not grant_type:
raise ValueError("Grant_type not provided")
conn_config["token"] = self._get_valid_oauth_token(
conn_config=conn_config,
token_endpoint=conn_config.get("token_endpoint"),
grant_type=grant_type,
)

conn_config.pop("login", None)
conn_config.pop("user", None)
conn_config.pop("password", None)
return conn_config

@cached_property
def _get_static_conn_params(self) -> dict[str, str | None]:
"""
Return static Snowflake connection parameters.

These parameters are cached for the lifetime of the hook and exclude
time-sensitive values such as OAuth access tokens. This is used in
``get_oauth_token()`` and ``_get_conn_params()``.
"""
conn = self.get_connection(self.get_conn_id())
extra_dict = conn.extra_dejson
Expand Down Expand Up @@ -388,25 +398,21 @@ def _get_conn_params(self) -> dict[str, str | None]:
conn_config["refresh_token"] = refresh_token
conn_config["authenticator"] = "oauth"

grant_type = self._get_field(extra_dict, "grant_type") or ""
if grant_type:
conn_config["grant_type"] = grant_type
elif refresh_token:
conn_config["grant_type"] = "refresh_token"

if conn_config.get("authenticator") == "oauth":
if extra_dict.get("azure_conn_id"):
conn_config["token"] = self.get_azure_oauth_token(extra_dict["azure_conn_id"])
else:
token_endpoint = self._get_field(extra_dict, "token_endpoint") or ""
conn_config["azure_conn_id"] = extra_dict.get("azure_conn_id")

if not extra_dict.get("azure_conn_id"):
conn_config["token_endpoint"] = self._get_field(extra_dict, "token_endpoint") or ""
conn_config["scope"] = self._get_field(extra_dict, "scope")
conn_config["client_id"] = conn.login
conn_config["client_secret"] = conn.password

conn_config["token"] = self.get_oauth_token(
conn_config=conn_config,
token_endpoint=token_endpoint,
grant_type=extra_dict.get("grant_type", "refresh_token"),
)

conn_config.pop("login", None)
conn_config.pop("user", None)
conn_config.pop("password", None)

# configure custom target hostname and port, if specified
snowflake_host = extra_dict.get("host")
snowflake_port = extra_dict.get("port")
Expand All @@ -423,9 +429,80 @@ def _get_conn_params(self) -> dict[str, str | None]:

return conn_config

def _get_valid_oauth_token(
    self,
    *,
    conn_config: dict[str, Any],
    token_endpoint: str | None,
    grant_type: str,
) -> str:
    """
    Return a valid OAuth access token, requesting a new one only when needed.

    The token cached on the hook is reused while the current time is before its
    recorded expiry. Otherwise a new token is requested from the token endpoint
    and the cache (``_oauth_token`` / ``_oauth_token_expires_at``) is updated.

    :param conn_config: Connection parameters. ``client_id`` and ``client_secret``
        are always read; ``redirect_uri`` and ``scope`` are optional;
        ``refresh_token`` is read for the ``refresh_token`` grant; ``account`` is
        read only when no ``token_endpoint`` is given.
    :param token_endpoint: URL of the token endpoint. When empty or ``None``, the
        Snowflake endpoint derived from ``conn_config["account"]`` is used.
    :param grant_type: Either ``"refresh_token"`` or ``"client_credentials"``.
    :return: The OAuth access token.
    :raises ValueError: If ``grant_type`` is not one of the supported values.
    :raises AirflowException: If the token endpoint answers with an HTTP error status.
    """
    # NOTE(review): the cache is not keyed by conn_config / token_endpoint /
    # grant_type, so a cached token is returned regardless of the arguments
    # passed on this call -- confirm that is acceptable for all callers.
    now = timezone.utcnow()
    cached_token = self._oauth_token
    cached_expiry = self._oauth_token_expires_at
    if cached_token is not None and cached_expiry is not None and now < cached_expiry:
        return cached_token

    url = token_endpoint or (
        f"https://{conn_config['account']}.snowflakecomputing.com/oauth/token-request"
    )

    data = {
        "grant_type": grant_type,
        "redirect_uri": conn_config.get("redirect_uri", "https://localhost.com"),
    }

    scope = conn_config.get("scope")
    if scope:
        data["scope"] = scope

    if grant_type == "refresh_token":
        data["refresh_token"] = conn_config["refresh_token"]
    elif grant_type != "client_credentials":
        # The client credentials grant needs no additional form fields.
        raise ValueError(f"Unknown grant_type: {grant_type}")

    response = requests.post(
        url,
        data=data,
        headers={
            "Content-Type": "application/x-www-form-urlencoded",
        },
        auth=HTTPBasicAuth(conn_config["client_id"], conn_config["client_secret"]),  # type: ignore[arg-type]
        timeout=OAUTH_REQUEST_TIMEOUT,
    )

    try:
        response.raise_for_status()
    except requests.exceptions.HTTPError as e:  # pragma: no cover
        msg = f"Response: {e.response.content.decode()} Status Code: {e.response.status_code}"
        # Chain the original error so the HTTP traceback is not lost.
        raise AirflowException(msg) from e

    # Decode the response body once and read both fields from the same payload.
    payload = response.json()
    token = payload["access_token"]
    # ``expires_in`` is optional in OAuth token responses; when it is missing the
    # lifetime is treated as 0 so the token is returned but never served from cache.
    expires_in = int(payload.get("expires_in") or 0)

    # Capture the issue timestamp after the access token is retrieved.
    issued_at = timezone.utcnow()

    # Store the token with an expiry shortened by the safety buffer (never negative),
    # so it is refreshed slightly before the server would reject it.
    self._oauth_token = token
    self._oauth_token_expires_at = issued_at + timedelta(
        seconds=max(expires_in - OAUTH_EXPIRY_BUFFER, 0)
    )

    return token

def get_uri(self) -> str:
"""Override DbApiHook get_uri method for get_sqlalchemy_engine()."""
conn_params = self._get_conn_params
conn_params = self._get_conn_params()
return self._conn_params_to_sqlalchemy_uri(conn_params)

def _conn_params_to_sqlalchemy_uri(self, conn_params: dict) -> str:
Expand All @@ -449,7 +526,7 @@ def _conn_params_to_sqlalchemy_uri(self, conn_params: dict) -> str:

def get_conn(self) -> SnowflakeConnection:
"""Return a snowflake.connection object."""
conn_config = self._get_conn_params
conn_config = self._get_conn_params()
conn = connector.connect(**conn_config)
return conn

Expand All @@ -461,7 +538,7 @@ def get_sqlalchemy_engine(self, engine_kwargs=None):
:return: the created engine.
"""
engine_kwargs = engine_kwargs or {}
conn_params = self._get_conn_params
conn_params = self._get_conn_params()
if "insecure_mode" in conn_params:
engine_kwargs.setdefault("connect_args", {})
engine_kwargs["connect_args"]["insecure_mode"] = True
Expand All @@ -488,7 +565,7 @@ def get_snowpark_session(self):
from airflow import __version__ as airflow_version
from airflow.providers.snowflake import __version__ as provider_version

conn_config = self._get_conn_params
conn_config = self._get_conn_params()
session = Session.builder.configs(conn_config).create()
# add query tag for observability
session.update_query_tag(
Expand Down Expand Up @@ -654,7 +731,7 @@ def get_openlineage_database_dialect(self, _) -> str:
return "snowflake"

def get_openlineage_default_schema(self) -> str | None:
return self._get_conn_params["schema"]
return self._get_conn_params()["schema"]

def _get_openlineage_authority(self, _) -> str | None:
uri = fix_snowflake_sqlalchemy_uri(self.get_uri())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def execute_query(
the statement with these specified values.
"""
self.query_ids = []
conn_config = self._get_conn_params
conn_config = self._get_conn_params()

req_id = uuid.uuid4()
url = f"{self.account_identifier}.snowflakecomputing.com/api/v2/statements"
Expand Down Expand Up @@ -206,7 +206,7 @@ def execute_query(

def get_headers(self) -> dict[str, Any]:
"""Form auth headers based on either OAuth token or JWT token from private key."""
conn_config = self._get_conn_params
conn_config = self._get_conn_params()

# Use OAuth if refresh_token and client_id and client_secret are provided
if all(
Expand Down
Loading