Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,24 @@ def _get_field(self, extra_dict, field_name):
return extra_dict[field_name] or None
return extra_dict.get(backcompat_key) or None

def get_oauth_token(self, conn_config: dict) -> str:
@property
def account_identifier(self) -> str:
"""Returns snowflake account identifier."""
conn_config = self._get_conn_params
account_identifier = f"https://{conn_config['account']}"

if conn_config["region"]:
account_identifier += f".{conn_config['region']}"

return account_identifier

def get_oauth_token(self, conn_config: dict | None = None) -> str:
"""Generate temporary OAuth access token using refresh token in connection details."""
url = f"https://{conn_config['account']}.snowflakecomputing.com/oauth/token-request"
if conn_config is None:
conn_config = self._get_conn_params

url = f"{self.account_identifier}.snowflakecomputing.com/oauth/token-request"

data = {
"grant_type": "refresh_token",
"refresh_token": conn_config["refresh_token"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import base64
import uuid
import warnings
from datetime import timedelta
from pathlib import Path
from typing import Any
Expand All @@ -26,9 +27,8 @@
import requests
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from requests.auth import HTTPBasicAuth

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
from airflow.providers.snowflake.utils.sql_api_generate_jwt import JWTGenerator

Expand Down Expand Up @@ -84,17 +84,6 @@ def __init__(
super().__init__(snowflake_conn_id, *args, **kwargs)
self.private_key: Any = None

@property
def account_identifier(self) -> str:
"""Returns snowflake account identifier."""
conn_config = self._get_conn_params
account_identifier = f"https://{conn_config['account']}"

if conn_config["region"]:
account_identifier += f".{conn_config['region']}"

return account_identifier

def get_private_key(self) -> None:
"""Get the private key from snowflake connection."""
conn = self.get_connection(self.snowflake_conn_id)
Expand Down Expand Up @@ -233,29 +222,14 @@ def get_headers(self) -> dict[str, Any]:
}
return headers

def get_oauth_token(self, conn_config: dict[str, Any]) -> str:
def get_oauth_token(self, conn_config: dict[str, Any] | None = None) -> str:
"""Generate temporary OAuth access token using refresh token in connection details."""
url = f"{self.account_identifier}.snowflakecomputing.com/oauth/token-request"
data = {
"grant_type": "refresh_token",
"refresh_token": conn_config["refresh_token"],
"redirect_uri": conn_config.get("redirect_uri", "https://localhost.com"),
}
response = requests.post(
url,
data=data,
headers={
"Content-Type": "application/x-www-form-urlencoded",
},
auth=HTTPBasicAuth(conn_config["client_id"], conn_config["client_secret"]), # type: ignore[arg-type]
warnings.warn(
"This method is deprecated. Please use `get_oauth_token` method from `SnowflakeHook` instead. ",
AirflowProviderDeprecationWarning,
stacklevel=2,
)

try:
response.raise_for_status()
except requests.exceptions.HTTPError as e: # pragma: no cover
msg = f"Response: {e.response.content.decode()} Status Code: {e.response.status_code}"
raise AirflowException(msg)
return response.json()["access_token"]
return super().get_oauth_token(conn_config=conn_config)

def get_request_url_header_params(self, query_id: str) -> tuple[dict[str, Any], dict[str, Any], str]:
"""
Expand Down
36 changes: 26 additions & 10 deletions providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,11 @@ def test_get_conn_params_should_fail_on_invalid_key(self):
SnowflakeHook(snowflake_conn_id="test_conn").get_conn()

@mock.patch("requests.post")
def test_get_conn_params_should_support_oauth(self, requests_post):
@mock.patch(
"airflow.providers.snowflake.hooks.snowflake.SnowflakeHook._get_conn_params",
new_callable=PropertyMock,
)
def test_get_conn_params_should_support_oauth(self, mock_get_conn_params, requests_post):
requests_post.return_value = Mock(
status_code=200,
json=lambda: {
Expand All @@ -533,15 +537,27 @@ def test_get_conn_params_should_support_oauth(self, requests_post):
"region": "af_region",
"role": "af_role",
"refresh_token": "secrettoken",
"authenticator": "oauth",
},
}
mock_get_conn_params.return_value = connection_kwargs
with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
hook = SnowflakeHook(snowflake_conn_id="test_conn")
assert "user" not in hook._get_conn_params
assert "password" not in hook._get_conn_params
assert "refresh_token" in hook._get_conn_params
assert "token" in hook._get_conn_params
assert hook._get_conn_params["authenticator"] == "oauth"
conn_params = hook._get_conn_params

conn_params_keys = conn_params.keys()
conn_params_extra = conn_params.get("extra", {})
conn_params_extra_keys = conn_params_extra.keys()

assert "authenticator" in conn_params_extra_keys
assert conn_params_extra["authenticator"] == "oauth"

assert "user" not in conn_params_keys
assert "password" in conn_params_keys
assert "refresh_token" in conn_params_extra_keys
# Mandatory fields to generate account_identifier `https://<account>.<region>`
assert "region" in conn_params_extra_keys
assert "account" in conn_params_extra_keys

def test_should_add_partner_info(self):
with mock.patch.dict(
Expand Down Expand Up @@ -885,19 +901,19 @@ def test_get_snowpark_session(self, mock_session_builder):
)
def test_get_oauth_token(self, mock_conn_param, requests_post, mock_auth):
"""Test get_oauth_token method makes the right http request"""
BASIC_AUTH = {"Authorization": "Basic usernamepassword"}
basic_auth = {"Authorization": "Basic usernamepassword"}
mock_conn_param.return_value = CONN_PARAMS_OAUTH
requests_post.return_value.status_code = 200
mock_auth.return_value = BASIC_AUTH
mock_auth.return_value = basic_auth
hook = SnowflakeHook(snowflake_conn_id="mock_conn_id")
hook.get_oauth_token(conn_config=CONN_PARAMS_OAUTH)
requests_post.assert_called_once_with(
f"https://{CONN_PARAMS_OAUTH['account']}.snowflakecomputing.com/oauth/token-request",
f"https://{CONN_PARAMS_OAUTH['account']}.{CONN_PARAMS_OAUTH['region']}.snowflakecomputing.com/oauth/token-request",
data={
"grant_type": "refresh_token",
"refresh_token": CONN_PARAMS_OAUTH["refresh_token"],
"redirect_uri": "https://localhost.com",
},
headers={"Content-Type": "application/x-www-form-urlencoded"},
auth=BASIC_AUTH,
auth=basic_auth,
)
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from cryptography.hazmat.primitives.asymmetric import rsa
from responses import RequestsMock

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models import Connection
from airflow.providers.snowflake.hooks.snowflake_sql_api import (
SnowflakeSqlApiHook,
Expand Down Expand Up @@ -352,20 +352,21 @@ def test_get_headers_should_support_oauth(self, mock_conn_param, mock_oauth_toke
result = hook.get_headers()
assert result == HEADERS_OAUTH

@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.HTTPBasicAuth")
@mock.patch("airflow.providers.snowflake.hooks.snowflake.HTTPBasicAuth")
@mock.patch("requests.post")
@mock.patch(
"airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params",
new_callable=PropertyMock,
)
def test_get_oauth_token(self, mock_conn_param, requests_post, mock_auth):
"""Test get_oauth_token method makes the right http request"""
BASIC_AUTH = {"Authorization": "Basic usernamepassword"}
basic_auth = {"Authorization": "Basic usernamepassword"}
mock_conn_param.return_value = CONN_PARAMS_OAUTH
requests_post.return_value.status_code = 200
mock_auth.return_value = BASIC_AUTH
mock_auth.return_value = basic_auth
hook = SnowflakeSqlApiHook(snowflake_conn_id="mock_conn_id")
hook.get_oauth_token(CONN_PARAMS_OAUTH)
with pytest.warns(expected_warning=AirflowProviderDeprecationWarning):
hook.get_oauth_token(CONN_PARAMS_OAUTH)
requests_post.assert_called_once_with(
f"https://{CONN_PARAMS_OAUTH['account']}.{CONN_PARAMS_OAUTH['region']}.snowflakecomputing.com/oauth/token-request",
data={
Expand All @@ -374,7 +375,7 @@ def test_get_oauth_token(self, mock_conn_param, requests_post, mock_auth):
"redirect_uri": "https://localhost.com",
},
headers={"Content-Type": "application/x-www-form-urlencoded"},
auth=BASIC_AUTH,
auth=basic_auth,
)

@pytest.fixture
Expand Down