Skip to content

Commit fd39795

Browse files
author
vishalup29
committed
Snowflake: Support passing OAuth scope parameter in client_credentials flow
1 parent b48c909 commit fd39795

File tree

3 files changed

+90
-4
lines changed

3 files changed

+90
-4
lines changed

providers/snowflake/docs/connections/snowflake.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ Extra (optional)
6060
* ``authenticator``: To connect using OAuth set this parameter ``oauth``.
6161
* ``token_endpoint``: Specify token endpoint for external OAuth provider.
6262
* ``grant_type``: Specify grant type for OAuth authentication. Currently supported: ``refresh_token`` (default), ``client_credentials``.
63+
* ``oauth_scope``: Optional OAuth scope to include when using the ``client_credentials`` grant type.
64+
Some identity providers (e.g., Okta, Auth0) require a scope when issuing client credentials tokens.
6365
* ``refresh_token``: Specify refresh_token for OAuth connection.
6466
* ``azure_conn_id``: Azure Connection ID to be used for retrieving the OAuth token using Azure Entra authentication. Login and Password fields aren't required when using this method. Scope for the Azure OAuth token can be set in the config option ``azure_oauth_scope`` under the section ``[snowflake]``. Requires `apache-airflow-providers-microsoft-azure>=12.8.0`.
6567
* ``private_key_file``: Specify the path to the private key file.

providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -224,11 +224,15 @@ def get_oauth_token(
224224
}
225225

226226
if grant_type == "refresh_token":
227-
data |= {
228-
"refresh_token": conn_config["refresh_token"],
229-
}
227+
# Older provider versions may not supply refresh_token; avoid KeyError
228+
refresh_token = conn_config.get("refresh_token")
229+
if not refresh_token:
230+
raise AirflowException("grant_type=refresh_token requires `refresh_token` in extra")
231+
data["refresh_token"] = refresh_token
230232
elif grant_type == "client_credentials":
231-
pass # no setup necessary for client credentials grant.
233+
oauth_scope = conn_config.get("oauth_scope")
234+
if oauth_scope:
235+
data["scope"] = oauth_scope
232236
else:
233237
raise ValueError(f"Unknown grant_type: {grant_type}")
234238

@@ -379,6 +383,10 @@ def _get_conn_params(self) -> dict[str, str | None]:
379383
conn_config.pop("password", None)
380384

381385
refresh_token = self._get_field(extra_dict, "refresh_token") or ""
386+
oauth_scope = self._get_field(extra_dict, "oauth_scope")
387+
if oauth_scope:
388+
conn_config["oauth_scope"] = oauth_scope
389+
382390
if refresh_token:
383391
conn_config["refresh_token"] = refresh_token
384392
conn_config["authenticator"] = "oauth"

providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1133,3 +1133,79 @@ def __init__(self, sdk_client, conn_id="azure_default"):
11331133

11341134
# Check AzureBaseHook initialization
11351135
mock_connection_class.get.assert_called_once_with(azure_conn_id)
1136+
1137+
@mock.patch("airflow.providers.snowflake.hooks.snowflake.HTTPBasicAuth")
1138+
@mock.patch("requests.post")
1139+
@mock.patch(
1140+
"airflow.providers.snowflake.hooks.snowflake.SnowflakeHook._get_conn_params",
1141+
new_callable=PropertyMock,
1142+
)
1143+
def test_get_oauth_token_client_credentials_without_scope(
1144+
self, mock_conn_params, requests_post,mock_auth
1145+
):
1146+
"""Test client_credentials flow without oauth_scope."""
1147+
mock_conn_params.return_value = {
1148+
**CONN_PARAMS_OAUTH_BASE,
1149+
"grant_type": "client_credentials",
1150+
}
1151+
1152+
basic_auth = {"Authorization": "Basic test"}
1153+
mock_auth.return_value = basic_auth
1154+
1155+
# Response mock
1156+
requests_post.return_value.status_code = 200
1157+
requests_post.return_value.json = lambda: {"access_token": "token123"}
1158+
1159+
hook = SnowflakeHook(snowflake_conn_id="test_conn")
1160+
hook.get_oauth_token(
1161+
conn_config=mock_conn_params.return_value,
1162+
grant_type="client_credentials",
1163+
)
1164+
1165+
requests_post.assert_called_once_with(
1166+
f"https://{CONN_PARAMS_OAUTH_BASE['account']}.snowflakecomputing.com/oauth/token-request",
1167+
data={
1168+
"grant_type": "client_credentials",
1169+
"redirect_uri": "https://localhost.com",
1170+
},
1171+
headers={"Content-Type": "application/x-www-form-urlencoded"},
1172+
auth=basic_auth,
1173+
)
1174+
1175+
@mock.patch("airflow.providers.snowflake.hooks.snowflake.HTTPBasicAuth")
1176+
@mock.patch("requests.post")
1177+
@mock.patch(
1178+
"airflow.providers.snowflake.hooks.snowflake.SnowflakeHook._get_conn_params",
1179+
new_callable=PropertyMock,
1180+
)
1181+
def test_get_oauth_token_client_credentials_with_scope(self, mock_conn_params, requests_post, mock_auth):
1182+
"""Test client_credentials flow WITH oauth_scope."""
1183+
mock_conn_params.return_value = {
1184+
**CONN_PARAMS_OAUTH_BASE,
1185+
"grant_type": "client_credentials",
1186+
"oauth_scope": "custom_scope",
1187+
}
1188+
1189+
basic_auth = {"Authorization": "Basic test"}
1190+
mock_auth.return_value = basic_auth
1191+
1192+
# Response mock
1193+
requests_post.return_value.status_code = 200
1194+
requests_post.return_value.json = lambda: {"access_token": "token123"}
1195+
1196+
hook = SnowflakeHook(snowflake_conn_id="test_conn")
1197+
hook.get_oauth_token(
1198+
conn_config=mock_conn_params.return_value,
1199+
grant_type="client_credentials",
1200+
)
1201+
1202+
requests_post.assert_called_once_with(
1203+
f"https://{CONN_PARAMS_OAUTH_BASE['account']}.snowflakecomputing.com/oauth/token-request",
1204+
data={
1205+
"grant_type": "client_credentials",
1206+
"redirect_uri": "https://localhost.com",
1207+
"scope": "custom_scope",
1208+
},
1209+
headers={"Content-Type": "application/x-www-form-urlencoded"},
1210+
auth=basic_auth,
1211+
)

0 commit comments

Comments
 (0)