Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions providers/snowflake/docs/connections/snowflake.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ Extra (optional)
* ``authenticator``: To connect using OAuth set this parameter ``oauth``.
* ``token_endpoint``: Specify token endpoint for external OAuth provider.
* ``grant_type``: Specify grant type for OAuth authentication. Currently supported: ``refresh_token`` (default), ``client_credentials``.
* ``scope``: Specify OAuth scope to include in the access token request for any OAuth grant type.
* ``refresh_token``: Specify refresh_token for OAuth connection.
* ``azure_conn_id``: Azure Connection ID to be used for retrieving the OAuth token using Azure Entra authentication. Login and Password fields aren't required when using this method. Scope for the Azure OAuth token can be set in the config option ``azure_oauth_scope`` under the section ``[snowflake]``. Requires `apache-airflow-providers-microsoft-azure>=12.8.0`.
* ``private_key_file``: Specify the path to the private key file.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]:
"grant_type": "refresh_token client_credentials",
"token_endpoint": "token endpoint",
"refresh_token": "refresh token",
"scope": "scope",
},
indent=1,
),
Expand Down Expand Up @@ -223,6 +224,11 @@ def get_oauth_token(
"redirect_uri": conn_config.get("redirect_uri", "https://localhost.com"),
}

scope = conn_config.get("scope")

if scope:
data["scope"] = scope

if grant_type == "refresh_token":
data |= {
"refresh_token": conn_config["refresh_token"],
Expand Down Expand Up @@ -388,8 +394,10 @@ def _get_conn_params(self) -> dict[str, str | None]:
conn_config["token"] = self.get_azure_oauth_token(extra_dict["azure_conn_id"])
else:
token_endpoint = self._get_field(extra_dict, "token_endpoint") or ""
conn_config["scope"] = self._get_field(extra_dict, "scope")
conn_config["client_id"] = conn.login
conn_config["client_secret"] = conn.password

conn_config["token"] = self.get_oauth_token(
conn_config=conn_config,
token_endpoint=token_endpoint,
Expand Down
108 changes: 108 additions & 0 deletions providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,44 @@ def test_get_conn_params_should_support_oauth_with_azure_conn_id(self, mocker):
assert "region" in conn_params
assert "account" in conn_params

@mock.patch("requests.post")
def test_get_conn_params_include_scope(self, mock_requests_post):
"""
Verify that `_get_conn_params` includes the `scope` field when it is present
in the connection extras.
"""
mock_requests_post.return_value = Mock(
status_code=200,
json=lambda: {
"access_token": "dummy",
"expires_in": 600,
"token_type": "Bearer",
"username": "test_user",
},
)

connection_kwargs = {
**BASE_CONNECTION_KWARGS,
"login": "test_client_id",
"password": "test_client_secret",
"extra": {
"account": "airflow",
"authenticator": "oauth",
"grant_type": "client_credentials",
"scope": "default",
},
}

with mock.patch.dict(
"os.environ",
{"AIRFLOW_CONN_TEST_CONN": Connection(**connection_kwargs).get_uri()},
):
hook = SnowflakeHook(snowflake_conn_id="test_conn")
params = hook._get_conn_params
mock_requests_post.assert_called_once()
assert "scope" in params
assert params["scope"] == "default"

def test_should_add_partner_info(self):
with mock.patch.dict(
"os.environ",
Expand Down Expand Up @@ -1133,3 +1171,73 @@ def __init__(self, sdk_client, conn_id="azure_default"):

# Check AzureBaseHook initialization
mock_connection_class.get.assert_called_once_with(azure_conn_id)

@mock.patch("requests.post")
def test_get_oauth_token_with_scope(self, mock_requests_post):
"""
Verify that `get_oauth_token` returns an access token and includes the
provided scope in the outgoing OAuth request payload.
"""

mock_requests_post.return_value = Mock(
status_code=200,
json=lambda: {"access_token": "dummy_token"},
)

connection_kwargs = {
**BASE_CONNECTION_KWARGS,
"login": "client_id",
"password": "client_secret",
"extra": {
"account": "airflow",
"authenticator": "oauth",
"grant_type": "client_credentials",
"scope": "default",
},
}

with mock.patch.dict(
"os.environ",
{"AIRFLOW_CONN_TEST_CONN": Connection(**connection_kwargs).get_uri()},
):
hook = SnowflakeHook(snowflake_conn_id="test_conn")
token = hook.get_oauth_token(grant_type="client_credentials")

assert token == "dummy_token"

called_data = mock_requests_post.call_args.kwargs["data"]

assert called_data["scope"] == "default"
assert called_data["grant_type"] == "client_credentials"

@mock.patch("requests.post")
def test_get_oauth_token_without_scope(self, mock_requests_post):
"""
Verify that `get_oauth_token` returns an access token and sends `scope=None`
when no scope is defined in the connection extras.
"""
mock_requests_post.return_value = Mock(
status_code=200,
json=lambda: {"access_token": "dummy_token"},
)

connection_kwargs = {
**BASE_CONNECTION_KWARGS,
"login": "client_id",
"password": "client_secret",
"extra": {"account": "airflow", "authenticator": "oauth", "grant_type": "client_credentials"},
}

with mock.patch.dict(
"os.environ",
{"AIRFLOW_CONN_TEST_CONN": Connection(**connection_kwargs).get_uri()},
):
hook = SnowflakeHook(snowflake_conn_id="test_conn")
token = hook.get_oauth_token(grant_type="client_credentials")

assert token == "dummy_token"

called_data = mock_requests_post.call_args.kwargs["data"]

assert "scope" not in called_data
assert called_data["grant_type"] == "client_credentials"