diff --git a/providers/snowflake/docs/configurations-ref.rst b/providers/snowflake/docs/configurations-ref.rst new file mode 100644 index 0000000000000..a52b21b2e5679 --- /dev/null +++ b/providers/snowflake/docs/configurations-ref.rst @@ -0,0 +1,19 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. include:: /../../../devel-common/src/sphinx_exts/includes/providers-configurations-ref.rst +.. include:: /../../../devel-common/src/sphinx_exts/includes/sections-and-options.rst diff --git a/providers/snowflake/docs/connections/snowflake.rst b/providers/snowflake/docs/connections/snowflake.rst index e15280ceb773f..3a523d97b4ee3 100644 --- a/providers/snowflake/docs/connections/snowflake.rst +++ b/providers/snowflake/docs/connections/snowflake.rst @@ -61,6 +61,7 @@ Extra (optional) * ``token_endpoint``: Specify token endpoint for external OAuth provider. * ``grant_type``: Specify grant type for OAuth authentication. Currently supported: ``refresh_token`` (default), ``client_credentials``. * ``refresh_token``: Specify refresh_token for OAuth connection. + * ``azure_conn_id``: Azure Connection ID to be used for retrieving the OAuth token using Azure Entra authentication. Login and Password fields aren't required when using this method. Scope for the Azure OAuth token can be set in the config option ``azure_oauth_scope`` under the section ``[snowflake]``. Requires `apache-airflow-providers-microsoft-azure>=12.8.0`. * ``private_key_file``: Specify the path to the private key file. * ``private_key_content``: Specify the content of the private key file in base64 encoded format. You can use the following Python code to encode the private key: diff --git a/providers/snowflake/docs/index.rst b/providers/snowflake/docs/index.rst index 0bf5c4e0e6a3d..cfa234ade6022 100644 --- a/providers/snowflake/docs/index.rst +++ b/providers/snowflake/docs/index.rst @@ -43,6 +43,7 @@ :maxdepth: 1 :caption: References + Configuration Python API <_api/airflow/providers/snowflake/index> .. toctree:: @@ -127,13 +128,14 @@ You can install such cross-provider dependencies when installing from PyPI. For pip install apache-airflow-providers-snowflake[common.compat] -================================================================================================================== ================= -Dependent package Extra -================================================================================================================== ================= -`apache-airflow-providers-common-compat `_ ``common.compat`` -`apache-airflow-providers-common-sql `_ ``common.sql`` -`apache-airflow-providers-openlineage `_ ``openlineage`` -================================================================================================================== ================= +====================================================================================================================== ================= +Dependent package Extra +====================================================================================================================== ================= +`apache-airflow-providers-common-compat `_ ``common.compat`` +`apache-airflow-providers-common-sql `_ ``common.sql`` +`apache-airflow-providers-openlineage `_ ``openlineage`` +`apache-airflow-providers-microsoft-azure `_ ``microsoft.azure`` +====================================================================================================================== ================= Downloading official packages ----------------------------- diff --git a/providers/snowflake/provider.yaml b/providers/snowflake/provider.yaml index 3fc0798fd1ebe..e399ef46c6e3e 100644 --- a/providers/snowflake/provider.yaml +++ b/providers/snowflake/provider.yaml @@ -143,3 +143,16 @@ triggers: - integration-name: Snowflake python-modules: - airflow.providers.snowflake.triggers.snowflake_trigger + +config: + snowflake: + description: | + Configuration for Snowflake hooks and operators. + options: + azure_oauth_scope: + description: | + The scope to use while retrieving OAuth token for Snowflake from Azure Entra authentication. + version_added: 6.6.0 + type: string + example: ~ + default: "api://snowflake_oauth_server/.default" diff --git a/providers/snowflake/pyproject.toml b/providers/snowflake/pyproject.toml index ea430e463ef00..2924a3d16bd1a 100644 --- a/providers/snowflake/pyproject.toml +++ b/providers/snowflake/pyproject.toml @@ -75,6 +75,9 @@ dependencies = [ # The optional dependencies should be modified in place in the generated file # Any change in the dependencies is preserved when the file is regenerated [project.optional-dependencies] +"microsoft.azure" = [ + "apache-airflow-providers-microsoft-azure" +] "openlineage" = [ "apache-airflow-providers-openlineage>=2.3.0" ] @@ -86,6 +89,7 @@ dev = [ "apache-airflow-devel-common", "apache-airflow-providers-common-compat", "apache-airflow-providers-common-sql", + "apache-airflow-providers-microsoft-azure", "apache-airflow-providers-openlineage", # Additional devel dependencies (do not remove this line and add extra development dependencies) "responses>=0.25.0", diff --git a/providers/snowflake/src/airflow/providers/snowflake/get_provider_info.py b/providers/snowflake/src/airflow/providers/snowflake/get_provider_info.py index fe76f4ec64efc..03ab85efa0482 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/get_provider_info.py +++ b/providers/snowflake/src/airflow/providers/snowflake/get_provider_info.py @@ -94,4 +94,18 @@ def get_provider_info(): "python-modules": ["airflow.providers.snowflake.triggers.snowflake_trigger"], } ], + "config": { + "snowflake": { + "description": "Configuration for Snowflake hooks and operators.\n", + "options": { + "azure_oauth_scope": { + "description": "The scope to use while retrieving OAuth token for Snowflake from Azure Entra authentication.\n", + "version_added": "6.6.0", + "type": "string", + "example": None, + "default": "api://snowflake_oauth_server/.default", + } + }, + } + }, } diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py index fd7e28804c19d..f28a7932d8b2e 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py @@ -36,12 +36,18 @@ from snowflake.sqlalchemy import URL from sqlalchemy import create_engine +from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.providers.common.sql.hooks.handlers import return_single_query_results from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.snowflake.utils.openlineage import fix_snowflake_sqlalchemy_uri from airflow.utils.strings import to_boolean +try: + from airflow.sdk import Connection +except ImportError: + from airflow.models.connection import Connection # type: ignore[assignment] + T = TypeVar("T") if TYPE_CHECKING: from airflow.providers.openlineage.extractors import OperatorLineage @@ -94,6 +100,7 @@ class SnowflakeHook(DbApiHook): hook_name = "Snowflake" supports_autocommit = True _test_connection_sql = "select 1" + default_azure_oauth_scope = "api://snowflake_oauth_server/.default" @classmethod def get_connection_form_widgets(cls) -> dict[str, Any]: @@ -246,6 +253,40 @@ def get_oauth_token( token = response.json()["access_token"] return token + def get_azure_oauth_token(self, azure_conn_id: str) -> str: + """ + Generate OAuth access token using Azure connection id. + + This uses AzureBaseHook on the connection id to retrieve the token. Scope for the OAuth token can be + set in the config option ``azure_oauth_scope`` under the section ``[snowflake]``. + + :param azure_conn_id: The connection id for the Azure connection that will be used to fetch the token. + :raises AttributeError: If AzureBaseHook does not have a get_token method which happens when + package apache-airflow-providers-microsoft-azure<12.8.0. + :returns: The OAuth access token string. + """ + if TYPE_CHECKING: + from airflow.providers.microsoft.azure.hooks.azure_base import AzureBaseHook + + try: + azure_conn = Connection.get(azure_conn_id) + except AttributeError: + azure_conn = Connection.get_connection_from_secrets(azure_conn_id) # type: ignore[attr-defined] + azure_base_hook: AzureBaseHook = azure_conn.get_hook() + scope = conf.get("snowflake", "azure_oauth_scope", fallback=self.default_azure_oauth_scope) + try: + token = azure_base_hook.get_token(scope).token + except AttributeError as e: + if e.name == "get_token" and e.obj == azure_base_hook: + raise AttributeError( + "'AzureBaseHook' object has no attribute 'get_token'. " + "Please upgrade apache-airflow-providers-microsoft-azure>=12.8.0", + name=e.name, + obj=e.obj, + ) from e + raise + return token + @cached_property def _get_conn_params(self) -> dict[str, str | None]: """ @@ -349,14 +390,17 @@ def _get_conn_params(self) -> dict[str, str | None]: conn_config["authenticator"] = "oauth" if conn_config.get("authenticator") == "oauth": - token_endpoint = self._get_field(extra_dict, "token_endpoint") or "" - conn_config["client_id"] = conn.login - conn_config["client_secret"] = conn.password - conn_config["token"] = self.get_oauth_token( - conn_config=conn_config, - token_endpoint=token_endpoint, - grant_type=extra_dict.get("grant_type", "refresh_token"), - ) + if extra_dict.get("azure_conn_id"): + conn_config["token"] = self.get_azure_oauth_token(extra_dict["azure_conn_id"]) + else: + token_endpoint = self._get_field(extra_dict, "token_endpoint") or "" + conn_config["client_id"] = conn.login + conn_config["client_secret"] = conn.password + conn_config["token"] = self.get_oauth_token( + conn_config=conn_config, + token_endpoint=token_endpoint, + grant_type=extra_dict.get("grant_type", "refresh_token"), + ) conn_config.pop("login", None) conn_config.pop("user", None) diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py index 1c9e4676daff5..6c7b0df885585 100644 --- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py +++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py @@ -666,6 +666,44 @@ def test_get_conn_params_should_support_oauth_with_client_credentials( assert "region" in conn_params_extra_keys assert "account" in conn_params_extra_keys + def test_get_conn_params_should_support_oauth_with_azure_conn_id(self, mocker): + azure_conn_id = "azure_test_conn" + mock_azure_token = "azure_test_token" + connection_kwargs = { + "extra": { + "database": "db", + "account": "airflow", + "region": "af_region", + "warehouse": "af_wh", + "authenticator": "oauth", + "azure_conn_id": azure_conn_id, + }, + } + + mock_connection_class = mocker.patch("airflow.providers.snowflake.hooks.snowflake.Connection") + mock_azure_base_hook = mock_connection_class.get.return_value.get_hook.return_value + mock_azure_base_hook.get_token.return_value.token = mock_azure_token + + with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()): + hook = SnowflakeHook(snowflake_conn_id="test_conn") + conn_params = hook._get_conn_params + + # Check AzureBaseHook initialization and get_token call args + mock_connection_class.get.assert_called_once_with(azure_conn_id) + mock_azure_base_hook.get_token.assert_called_once_with(SnowflakeHook.default_azure_oauth_scope) + + assert "authenticator" in conn_params + assert conn_params["authenticator"] == "oauth" + assert "token" in conn_params + assert conn_params["token"] == mock_azure_token + + assert "user" not in conn_params + assert "password" not in conn_params + assert "refresh_token" not in conn_params + # Mandatory fields to generate account_identifier `https://.` + assert "region" in conn_params + assert "account" in conn_params + def test_should_add_partner_info(self): with mock.patch.dict( "os.environ", @@ -1054,3 +1092,44 @@ def test_get_oauth_token_with_token_endpoint(self, mock_conn_param, requests_pos headers={"Content-Type": "application/x-www-form-urlencoded"}, auth=basic_auth, ) + + def test_get_azure_oauth_token(self, mocker): + """Test get_azure_oauth_token method gets token from provided connection id""" + azure_conn_id = "azure_test_conn" + mock_azure_token = "azure_test_token" + + mock_connection_class = mocker.patch("airflow.providers.snowflake.hooks.snowflake.Connection") + mock_azure_base_hook = mock_connection_class.get.return_value.get_hook.return_value + mock_azure_base_hook.get_token.return_value.token = mock_azure_token + + hook = SnowflakeHook(snowflake_conn_id="mock_conn_id") + token = hook.get_azure_oauth_token(azure_conn_id) + + # Check AzureBaseHook initialization and get_token call args + mock_connection_class.get.assert_called_once_with(azure_conn_id) + mock_azure_base_hook.get_token.assert_called_once_with(SnowflakeHook.default_azure_oauth_scope) + assert token == mock_azure_token + + def test_get_azure_oauth_token_expect_failure_on_get_token(self, mocker): + """Test get_azure_oauth_token method gets token from provided connection id""" + + class MockAzureBaseHookWithoutGetToken: + def __init__(self): + pass + + azure_conn_id = "azure_test_conn" + mock_connection_class = mocker.patch("airflow.providers.snowflake.hooks.snowflake.Connection") + mock_connection_class.get.return_value.get_hook.return_value = MockAzureBaseHookWithoutGetToken() + + hook = SnowflakeHook(snowflake_conn_id="mock_conn_id") + with pytest.raises( + AttributeError, + match=( + "'AzureBaseHook' object has no attribute 'get_token'. " + "Please upgrade apache-airflow-providers-microsoft-azure>=" + ), + ): + hook.get_azure_oauth_token(azure_conn_id) + + # Check AzureBaseHook initialization + mock_connection_class.get.assert_called_once_with(azure_conn_id)