Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions providers/snowflake/docs/configurations-ref.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
.. Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at

.. http://www.apache.org/licenses/LICENSE-2.0

.. Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.

.. include:: /../../../devel-common/src/sphinx_exts/includes/providers-configurations-ref.rst
.. include:: /../../../devel-common/src/sphinx_exts/includes/sections-and-options.rst
1 change: 1 addition & 0 deletions providers/snowflake/docs/connections/snowflake.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ Extra (optional)
* ``token_endpoint``: Specify token endpoint for external OAuth provider.
* ``grant_type``: Specify grant type for OAuth authentication. Currently supported: ``refresh_token`` (default), ``client_credentials``.
* ``refresh_token``: Specify refresh_token for OAuth connection.
* ``azure_conn_id``: Azure Connection ID to be used for retrieving the OAuth token using Azure Entra authentication. Login and Password fields aren't required when using this method. Scope for the Azure OAuth token can be set in the config option ``azure_oauth_scope`` under the section ``[snowflake]``. Requires `apache-airflow-providers-microsoft-azure>=12.8.0`.
* ``private_key_file``: Specify the path to the private key file.
* ``private_key_content``: Specify the content of the private key file in base64 encoded format. You can use the following Python code to encode the private key:

Expand Down
16 changes: 9 additions & 7 deletions providers/snowflake/docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
:maxdepth: 1
:caption: References

Configuration <configurations-ref>
Python API <_api/airflow/providers/snowflake/index>

.. toctree::
Expand Down Expand Up @@ -127,13 +128,14 @@ You can install such cross-provider dependencies when installing from PyPI. For
pip install apache-airflow-providers-snowflake[common.compat]


================================================================================================================== =================
Dependent package Extra
================================================================================================================== =================
`apache-airflow-providers-common-compat <https://airflow.apache.org/docs/apache-airflow-providers-common-compat>`_ ``common.compat``
`apache-airflow-providers-common-sql <https://airflow.apache.org/docs/apache-airflow-providers-common-sql>`_ ``common.sql``
`apache-airflow-providers-openlineage <https://airflow.apache.org/docs/apache-airflow-providers-openlineage>`_ ``openlineage``
================================================================================================================== =================
====================================================================================================================== =================
Dependent package Extra
====================================================================================================================== =================
`apache-airflow-providers-common-compat <https://airflow.apache.org/docs/apache-airflow-providers-common-compat>`_ ``common.compat``
`apache-airflow-providers-common-sql <https://airflow.apache.org/docs/apache-airflow-providers-common-sql>`_ ``common.sql``
`apache-airflow-providers-openlineage <https://airflow.apache.org/docs/apache-airflow-providers-openlineage>`_ ``openlineage``
`apache-airflow-providers-microsoft-azure <https://airflow.apache.org/docs/apache-airflow-providers-microsoft-azure>`_ ``microsoft.azure``
====================================================================================================================== =================

Downloading official packages
-----------------------------
Expand Down
13 changes: 13 additions & 0 deletions providers/snowflake/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,16 @@ triggers:
- integration-name: Snowflake
python-modules:
- airflow.providers.snowflake.triggers.snowflake_trigger

config:
snowflake:
description: |
Configuration for Snowflake hooks and operators.
options:
azure_oauth_scope:
description: |
The scope to use while retrieving OAuth token for Snowflake from Azure Entra authentication.
version_added: 6.6.0
type: string
example: ~
default: "api://snowflake_oauth_server/.default"
4 changes: 4 additions & 0 deletions providers/snowflake/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ dependencies = [
# The optional dependencies should be modified in place in the generated file
# Any change in the dependencies is preserved when the file is regenerated
[project.optional-dependencies]
"microsoft.azure" = [
"apache-airflow-providers-microsoft-azure"
]
"openlineage" = [
"apache-airflow-providers-openlineage>=2.3.0"
]
Expand All @@ -86,6 +89,7 @@ dev = [
"apache-airflow-devel-common",
"apache-airflow-providers-common-compat",
"apache-airflow-providers-common-sql",
"apache-airflow-providers-microsoft-azure",
"apache-airflow-providers-openlineage",
# Additional devel dependencies (do not remove this line and add extra development dependencies)
"responses>=0.25.0",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,18 @@ def get_provider_info():
"python-modules": ["airflow.providers.snowflake.triggers.snowflake_trigger"],
}
],
"config": {
"snowflake": {
"description": "Configuration for Snowflake hooks and operators.\n",
"options": {
"azure_oauth_scope": {
"description": "The scope to use while retrieving OAuth token for Snowflake from Azure Entra authentication.\n",
"version_added": "6.6.0",
"type": "string",
"example": None,
"default": "api://snowflake_oauth_server/.default",
}
},
}
},
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,18 @@
from snowflake.sqlalchemy import URL
from sqlalchemy import create_engine

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.common.sql.hooks.handlers import return_single_query_results
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.snowflake.utils.openlineage import fix_snowflake_sqlalchemy_uri
from airflow.utils.strings import to_boolean

try:
from airflow.sdk import Connection
except ImportError:
from airflow.models.connection import Connection # type: ignore[assignment]

T = TypeVar("T")
if TYPE_CHECKING:
from airflow.providers.openlineage.extractors import OperatorLineage
Expand Down Expand Up @@ -94,6 +100,7 @@ class SnowflakeHook(DbApiHook):
hook_name = "Snowflake"
supports_autocommit = True
_test_connection_sql = "select 1"
default_azure_oauth_scope = "api://snowflake_oauth_server/.default"

@classmethod
def get_connection_form_widgets(cls) -> dict[str, Any]:
Expand Down Expand Up @@ -246,6 +253,40 @@ def get_oauth_token(
token = response.json()["access_token"]
return token

def get_azure_oauth_token(self, azure_conn_id: str) -> str:
"""
Generate OAuth access token using Azure connection id.

This uses AzureBaseHook on the connection id to retrieve the token. Scope for the OAuth token can be
set in the config option ``azure_oauth_scope`` under the section ``[snowflake]``.

:param azure_conn_id: The connection id for the Azure connection that will be used to fetch the token.
:raises AttributeError: If AzureBaseHook does not have a get_token method which happens when
package apache-airflow-providers-microsoft-azure<12.8.0.
:returns: The OAuth access token string.
"""
if TYPE_CHECKING:
from airflow.providers.microsoft.azure.hooks.azure_base import AzureBaseHook

try:
azure_conn = Connection.get(azure_conn_id)
except AttributeError:
azure_conn = Connection.get_connection_from_secrets(azure_conn_id) # type: ignore[attr-defined]
azure_base_hook: AzureBaseHook = azure_conn.get_hook()
scope = conf.get("snowflake", "azure_oauth_scope", fallback=self.default_azure_oauth_scope)
try:
token = azure_base_hook.get_token(scope).token
except AttributeError as e:
if e.name == "get_token" and e.obj == azure_base_hook:
raise AttributeError(
"'AzureBaseHook' object has no attribute 'get_token'. "
"Please upgrade apache-airflow-providers-microsoft-azure>=12.8.0",
name=e.name,
obj=e.obj,
) from e
raise
return token

@cached_property
def _get_conn_params(self) -> dict[str, str | None]:
"""
Expand Down Expand Up @@ -349,14 +390,17 @@ def _get_conn_params(self) -> dict[str, str | None]:
conn_config["authenticator"] = "oauth"

if conn_config.get("authenticator") == "oauth":
token_endpoint = self._get_field(extra_dict, "token_endpoint") or ""
conn_config["client_id"] = conn.login
conn_config["client_secret"] = conn.password
conn_config["token"] = self.get_oauth_token(
conn_config=conn_config,
token_endpoint=token_endpoint,
grant_type=extra_dict.get("grant_type", "refresh_token"),
)
if extra_dict.get("azure_conn_id"):
conn_config["token"] = self.get_azure_oauth_token(extra_dict["azure_conn_id"])
else:
token_endpoint = self._get_field(extra_dict, "token_endpoint") or ""
conn_config["client_id"] = conn.login
conn_config["client_secret"] = conn.password
conn_config["token"] = self.get_oauth_token(
conn_config=conn_config,
token_endpoint=token_endpoint,
grant_type=extra_dict.get("grant_type", "refresh_token"),
)

conn_config.pop("login", None)
conn_config.pop("user", None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,44 @@ def test_get_conn_params_should_support_oauth_with_client_credentials(
assert "region" in conn_params_extra_keys
assert "account" in conn_params_extra_keys

def test_get_conn_params_should_support_oauth_with_azure_conn_id(self, mocker):
azure_conn_id = "azure_test_conn"
mock_azure_token = "azure_test_token"
connection_kwargs = {
"extra": {
"database": "db",
"account": "airflow",
"region": "af_region",
"warehouse": "af_wh",
"authenticator": "oauth",
"azure_conn_id": azure_conn_id,
},
}

mock_connection_class = mocker.patch("airflow.providers.snowflake.hooks.snowflake.Connection")
mock_azure_base_hook = mock_connection_class.get.return_value.get_hook.return_value
mock_azure_base_hook.get_token.return_value.token = mock_azure_token

with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
hook = SnowflakeHook(snowflake_conn_id="test_conn")
conn_params = hook._get_conn_params

# Check AzureBaseHook initialization and get_token call args
mock_connection_class.get.assert_called_once_with(azure_conn_id)
mock_azure_base_hook.get_token.assert_called_once_with(SnowflakeHook.default_azure_oauth_scope)

assert "authenticator" in conn_params
assert conn_params["authenticator"] == "oauth"
assert "token" in conn_params
assert conn_params["token"] == mock_azure_token

assert "user" not in conn_params
assert "password" not in conn_params
assert "refresh_token" not in conn_params
# Mandatory fields to generate account_identifier `https://<account>.<region>`
assert "region" in conn_params
assert "account" in conn_params

def test_should_add_partner_info(self):
with mock.patch.dict(
"os.environ",
Expand Down Expand Up @@ -1054,3 +1092,44 @@ def test_get_oauth_token_with_token_endpoint(self, mock_conn_param, requests_pos
headers={"Content-Type": "application/x-www-form-urlencoded"},
auth=basic_auth,
)

def test_get_azure_oauth_token(self, mocker):
"""Test get_azure_oauth_token method gets token from provided connection id"""
azure_conn_id = "azure_test_conn"
mock_azure_token = "azure_test_token"

mock_connection_class = mocker.patch("airflow.providers.snowflake.hooks.snowflake.Connection")
mock_azure_base_hook = mock_connection_class.get.return_value.get_hook.return_value
mock_azure_base_hook.get_token.return_value.token = mock_azure_token

hook = SnowflakeHook(snowflake_conn_id="mock_conn_id")
token = hook.get_azure_oauth_token(azure_conn_id)

# Check AzureBaseHook initialization and get_token call args
mock_connection_class.get.assert_called_once_with(azure_conn_id)
mock_azure_base_hook.get_token.assert_called_once_with(SnowflakeHook.default_azure_oauth_scope)
assert token == mock_azure_token

def test_get_azure_oauth_token_expect_failure_on_get_token(self, mocker):
"""Test get_azure_oauth_token method gets token from provided connection id"""

class MockAzureBaseHookWithoutGetToken:
def __init__(self):
pass

azure_conn_id = "azure_test_conn"
mock_connection_class = mocker.patch("airflow.providers.snowflake.hooks.snowflake.Connection")
mock_connection_class.get.return_value.get_hook.return_value = MockAzureBaseHookWithoutGetToken()

hook = SnowflakeHook(snowflake_conn_id="mock_conn_id")
with pytest.raises(
AttributeError,
match=(
"'AzureBaseHook' object has no attribute 'get_token'. "
"Please upgrade apache-airflow-providers-microsoft-azure>="
),
):
hook.get_azure_oauth_token(azure_conn_id)

# Check AzureBaseHook initialization
mock_connection_class.get.assert_called_once_with(azure_conn_id)
Loading