Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions providers/snowflake/docs/connections/snowflake.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ Extra (optional)
* ``warehouse``: Snowflake warehouse name.
* ``role``: Snowflake role.
* ``authenticator``: To connect using OAuth set this parameter ``oauth``.
* ``token_endpoint``: Specify token endpoint for external OAuth provider.
* ``grant_type``: Specify grant type for OAuth authentication. Currently supported: ``refresh_token`` (default), ``client_credentials``.
* ``refresh_token``: Specify refresh_token for OAuth connection.
* ``private_key_file``: Specify the path to the private key file.
* ``private_key_content``: Specify the content of the private key file in base64 encoded format. You can use the following Python code to encode the private key:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]:
"session_parameters": "session parameters",
"client_request_mfa_token": "client request mfa token",
"client_store_temporary_credential": "client store temporary credential (externalbrowser mode)",
"grant_type": "refresh_token client_credentials",
"token_endpoint": "token endpoint",
"refresh_token": "refresh token",
},
indent=1,
),
Expand Down Expand Up @@ -200,18 +203,32 @@ def account_identifier(self) -> str:

return account_identifier

def get_oauth_token(self, conn_config: dict | None = None) -> str:
def get_oauth_token(
self,
conn_config: dict | None = None,
token_endpoint: str | None = None,
grant_type: str = "refresh_token",
) -> str:
"""Generate temporary OAuth access token using refresh token in connection details."""
if conn_config is None:
conn_config = self._get_conn_params

url = f"https://{conn_config['account']}.snowflakecomputing.com/oauth/token-request"
url = token_endpoint or f"https://{conn_config['account']}.snowflakecomputing.com/oauth/token-request"

data = {
"grant_type": "refresh_token",
"refresh_token": conn_config["refresh_token"],
"grant_type": grant_type,
"redirect_uri": conn_config.get("redirect_uri", "https://localhost.com"),
}

if grant_type == "refresh_token":
data |= {
"refresh_token": conn_config["refresh_token"],
}
elif grant_type == "client_credentials":
pass # no setup necessary for client credentials grant.
else:
raise ValueError(f"Unknown grant_type: {grant_type}")

response = requests.post(
url,
data=data,
Expand All @@ -226,7 +243,8 @@ def get_oauth_token(self, conn_config: dict | None = None) -> str:
except requests.exceptions.HTTPError as e: # pragma: no cover
msg = f"Response: {e.response.content.decode()} Status Code: {e.response.status_code}"
raise AirflowException(msg)
return response.json()["access_token"]
token = response.json()["access_token"]
return token

@cached_property
def _get_conn_params(self) -> dict[str, str | None]:
Expand Down Expand Up @@ -329,14 +347,21 @@ def _get_conn_params(self) -> dict[str, str | None]:
if refresh_token:
conn_config["refresh_token"] = refresh_token
conn_config["authenticator"] = "oauth"

if conn_config.get("authenticator") == "oauth":
token_endpoint = self._get_field(extra_dict, "token_endpoint") or ""
conn_config["client_id"] = conn.login
conn_config["client_secret"] = conn.password
conn_config["token"] = self.get_oauth_token(
conn_config=conn_config,
token_endpoint=token_endpoint,
grant_type=extra_dict.get("grant_type", "refresh_token"),
)

conn_config.pop("login", None)
conn_config.pop("user", None)
conn_config.pop("password", None)

conn_config["token"] = self.get_oauth_token(conn_config=conn_config)

# configure custom target hostname and port, if specified
snowflake_host = extra_dict.get("host")
snowflake_port = extra_dict.get("port")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,14 +222,21 @@ def get_headers(self) -> dict[str, Any]:
}
return headers

def get_oauth_token(self, conn_config: dict[str, Any] | None = None) -> str:
def get_oauth_token(
self,
conn_config: dict[str, Any] | None = None,
token_endpoint: str | None = None,
grant_type: str = "refresh_token",
) -> str:
"""Generate temporary OAuth access token using refresh token in connection details."""
warnings.warn(
"This method is deprecated. Please use `get_oauth_token` method from `SnowflakeHook` instead. ",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
return super().get_oauth_token(conn_config=conn_config)
return super().get_oauth_token(
conn_config=conn_config, token_endpoint=token_endpoint, grant_type=grant_type
)

def get_request_url_header_params(self, query_id: str) -> tuple[dict[str, Any], dict[str, Any], str]:
"""
Expand Down
139 changes: 137 additions & 2 deletions providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,21 +53,22 @@
},
}

CONN_PARAMS_OAUTH = {
CONN_PARAMS_OAUTH_BASE = {
"account": "airflow",
"application": "AIRFLOW",
"authenticator": "oauth",
"database": "db",
"client_id": "test_client_id",
"client_secret": "test_client_pw",
"refresh_token": "secrettoken",
"region": "af_region",
"role": "af_role",
"schema": "public",
"session_parameters": None,
"warehouse": "af_wh",
}

CONN_PARAMS_OAUTH = CONN_PARAMS_OAUTH_BASE | {"refresh_token": "secrettoken"}


@pytest.fixture
def unencrypted_temporary_private_key(tmp_path: Path) -> Path:
Expand Down Expand Up @@ -559,6 +560,112 @@ def test_get_conn_params_should_support_oauth(self, mock_get_conn_params, reques
assert "region" in conn_params_extra_keys
assert "account" in conn_params_extra_keys

@mock.patch("requests.post")
@mock.patch(
"airflow.providers.snowflake.hooks.snowflake.SnowflakeHook._get_conn_params",
new_callable=PropertyMock,
)
def test_get_conn_params_should_support_oauth_with_token_endpoint(
self, mock_get_conn_params, requests_post
):
requests_post.return_value = Mock(
status_code=200,
json=lambda: {
"access_token": "supersecretaccesstoken",
"expires_in": 600,
"refresh_token": "secrettoken",
"token_type": "Bearer",
"username": "test_user",
},
)
connection_kwargs = {
**BASE_CONNECTION_KWARGS,
"login": "test_client_id",
"password": "test_client_secret",
"extra": {
"database": "db",
"account": "airflow",
"warehouse": "af_wh",
"region": "af_region",
"role": "af_role",
"refresh_token": "secrettoken",
"authenticator": "oauth",
"token_endpoint": "https://www.example.com/oauth/token",
},
}
mock_get_conn_params.return_value = connection_kwargs
with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
hook = SnowflakeHook(snowflake_conn_id="test_conn")
conn_params = hook._get_conn_params

conn_params_keys = conn_params.keys()
conn_params_extra = conn_params.get("extra", {})
conn_params_extra_keys = conn_params_extra.keys()

assert "authenticator" in conn_params_extra_keys
assert conn_params_extra["authenticator"] == "oauth"
assert conn_params_extra["token_endpoint"] == "https://www.example.com/oauth/token"

assert "user" not in conn_params_keys
assert "password" in conn_params_keys
assert "refresh_token" in conn_params_extra_keys
# Mandatory fields to generate account_identifier `https://<account>.<region>`
assert "region" in conn_params_extra_keys
assert "account" in conn_params_extra_keys

@mock.patch("requests.post")
@mock.patch(
"airflow.providers.snowflake.hooks.snowflake.SnowflakeHook._get_conn_params",
new_callable=PropertyMock,
)
def test_get_conn_params_should_support_oauth_with_client_credentials(
self, mock_get_conn_params, requests_post
):
requests_post.return_value = Mock(
status_code=200,
json=lambda: {
"access_token": "supersecretaccesstoken",
"expires_in": 600,
"refresh_token": "secrettoken",
"token_type": "Bearer",
"username": "test_user",
},
)
connection_kwargs = {
**BASE_CONNECTION_KWARGS,
"login": "test_client_id",
"password": "test_client_secret",
"extra": {
"database": "db",
"account": "airflow",
"warehouse": "af_wh",
"region": "af_region",
"role": "af_role",
"authenticator": "oauth",
"token_endpoint": "https://www.example.com/oauth/token",
"grant_type": "client_credentials",
},
}
mock_get_conn_params.return_value = connection_kwargs
with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
hook = SnowflakeHook(snowflake_conn_id="test_conn")
conn_params = hook._get_conn_params

conn_params_keys = conn_params.keys()
conn_params_extra = conn_params.get("extra", {})
conn_params_extra_keys = conn_params_extra.keys()

assert "authenticator" in conn_params_extra_keys
assert conn_params_extra["authenticator"] == "oauth"
assert conn_params_extra["grant_type"] == "client_credentials"

assert "user" not in conn_params_keys
assert "password" in conn_params_keys
assert "refresh_token" not in conn_params_extra_keys
# Mandatory fields to generate account_identifier `https://<account>.<region>`
assert "region" in conn_params_extra_keys
assert "account" in conn_params_extra_keys

def test_should_add_partner_info(self):
with mock.patch.dict(
"os.environ",
Expand Down Expand Up @@ -917,3 +1024,31 @@ def test_get_oauth_token(self, mock_conn_param, requests_post, mock_auth):
headers={"Content-Type": "application/x-www-form-urlencoded"},
auth=basic_auth,
)

@mock.patch("airflow.providers.snowflake.hooks.snowflake.HTTPBasicAuth")
@mock.patch("requests.post")
@mock.patch(
"airflow.providers.snowflake.hooks.snowflake.SnowflakeHook._get_conn_params",
new_callable=PropertyMock,
)
def test_get_oauth_token_with_token_endpoint(self, mock_conn_param, requests_post, mock_auth):
"""Test get_oauth_token method makes the right http request"""
basic_auth = {"Authorization": "Basic usernamepassword"}
token_endpoint = "https://example.com/oauth/token"
mock_conn_param.return_value = CONN_PARAMS_OAUTH
requests_post.return_value.status_code = 200
mock_auth.return_value = basic_auth

hook = SnowflakeHook(snowflake_conn_id="mock_conn_id")
hook.get_oauth_token(conn_config=CONN_PARAMS_OAUTH, token_endpoint=token_endpoint)

requests_post.assert_called_once_with(
token_endpoint,
data={
"grant_type": "refresh_token",
"refresh_token": CONN_PARAMS_OAUTH["refresh_token"],
"redirect_uri": "https://localhost.com",
},
headers={"Content-Type": "application/x-www-form-urlencoded"},
auth=basic_auth,
)