Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion providers/snowflake/docs/connections/snowflake.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,11 @@ Configuring the Connection
--------------------------

Login
Specify the snowflake username.
Specify the snowflake username. For OAuth, the OAuth Client ID.

Password
Specify the snowflake password. For public key authentication, the passphrase for the private key.
For OAuth, the OAuth Client Secret.

Schema (optional)
Specify the snowflake schema to be used.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload
from urllib.parse import urlparse

import requests
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from requests.auth import HTTPBasicAuth
from snowflake import connector
from snowflake.connector import DictCursor, SnowflakeConnection, util_text
from snowflake.sqlalchemy import URL
Expand Down Expand Up @@ -185,6 +187,30 @@ def _get_field(self, extra_dict, field_name):
return extra_dict[field_name] or None
return extra_dict.get(backcompat_key) or None

def get_oauth_token(self, conn_config: dict) -> str:
"""Generate temporary OAuth access token using refresh token in connection details."""
url = f"https://{conn_config['account']}.snowflakecomputing.com/oauth/token-request"
data = {
"grant_type": "refresh_token",
"refresh_token": conn_config["refresh_token"],
"redirect_uri": conn_config.get("redirect_uri", "https://localhost.com"),
}
response = requests.post(
url,
data=data,
headers={
"Content-Type": "application/x-www-form-urlencoded",
},
auth=HTTPBasicAuth(conn_config["client_id"], conn_config["client_secret"]), # type: ignore[arg-type]
)

try:
response.raise_for_status()
except requests.exceptions.HTTPError as e: # pragma: no cover
msg = f"Response: {e.response.content.decode()} Status Code: {e.response.status_code}"
raise AirflowException(msg)
return response.json()["access_token"]

@cached_property
def _get_conn_params(self) -> dict[str, str | None]:
"""
Expand Down Expand Up @@ -289,8 +315,11 @@ def _get_conn_params(self) -> dict[str, str | None]:
conn_config["client_id"] = conn.login
conn_config["client_secret"] = conn.password
conn_config.pop("login", None)
conn_config.pop("user", None)
conn_config.pop("password", None)

conn_config["token"] = self.get_oauth_token(conn_config=conn_config)

# configure custom target hostname and port, if specified
snowflake_host = extra_dict.get("host")
snowflake_port = extra_dict.get("port")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from copy import deepcopy
from typing import TYPE_CHECKING, Any
from unittest import mock
from unittest.mock import Mock, PropertyMock

import pytest
from cryptography.hazmat.backends import default_backend
Expand Down Expand Up @@ -51,6 +52,21 @@
},
}

CONN_PARAMS_OAUTH = {
"account": "airflow",
"application": "AIRFLOW",
"authenticator": "oauth",
"database": "db",
"client_id": "test_client_id",
"client_secret": "test_client_pw",
"refresh_token": "secrettoken",
"region": "af_region",
"role": "af_role",
"schema": "public",
"session_parameters": None,
"warehouse": "af_wh",
}


@pytest.fixture
def non_encrypted_temporary_private_key(tmp_path: Path) -> Path:
Expand Down Expand Up @@ -483,6 +499,39 @@ def test_get_conn_params_should_fail_on_invalid_key(self):
):
SnowflakeHook(snowflake_conn_id="test_conn").get_conn()

@mock.patch("requests.post")
def test_get_conn_params_should_support_oauth(self, requests_post):
requests_post.return_value = Mock(
status_code=200,
json=lambda: {
"access_token": "supersecretaccesstoken",
"expires_in": 600,
"refresh_token": "secrettoken",
"token_type": "Bearer",
"username": "test_user",
},
)
connection_kwargs = {
**BASE_CONNECTION_KWARGS,
"login": "test_client_id",
"password": "test_client_secret",
"extra": {
"database": "db",
"account": "airflow",
"warehouse": "af_wh",
"region": "af_region",
"role": "af_role",
"refresh_token": "secrettoken",
},
}
with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
hook = SnowflakeHook(snowflake_conn_id="test_conn")
assert "user" not in hook._get_conn_params
assert "password" not in hook._get_conn_params
assert "refresh_token" in hook._get_conn_params
assert "token" in hook._get_conn_params
assert hook._get_conn_params["authenticator"] == "oauth"

def test_should_add_partner_info(self):
with mock.patch.dict(
"os.environ",
Expand Down Expand Up @@ -816,3 +865,28 @@ def test_get_snowpark_session(self, mock_session_builder):
"airflow_provider_version": provider_version,
}
)

@mock.patch("airflow.providers.snowflake.hooks.snowflake.HTTPBasicAuth")
@mock.patch("requests.post")
@mock.patch(
"airflow.providers.snowflake.hooks.snowflake.SnowflakeHook._get_conn_params",
new_callable=PropertyMock,
)
def test_get_oauth_token(self, mock_conn_param, requests_post, mock_auth):
"""Test get_oauth_token method makes the right http request"""
BASIC_AUTH = {"Authorization": "Basic usernamepassword"}
mock_conn_param.return_value = CONN_PARAMS_OAUTH
requests_post.return_value.status_code = 200
mock_auth.return_value = BASIC_AUTH
hook = SnowflakeHook(snowflake_conn_id="mock_conn_id")
hook.get_oauth_token(conn_config=CONN_PARAMS_OAUTH)
requests_post.assert_called_once_with(
f"https://{CONN_PARAMS_OAUTH['account']}.snowflakecomputing.com/oauth/token-request",
data={
"grant_type": "refresh_token",
"refresh_token": CONN_PARAMS_OAUTH["refresh_token"],
"redirect_uri": "https://localhost.com",
},
headers={"Content-Type": "application/x-www-form-urlencoded"},
auth=BASIC_AUTH,
)