diff --git a/providers/snowflake/docs/changelog.rst b/providers/snowflake/docs/changelog.rst index 5b77778b03fa1..fc59ea33593dc 100644 --- a/providers/snowflake/docs/changelog.rst +++ b/providers/snowflake/docs/changelog.rst @@ -27,6 +27,24 @@ Changelog --------- +main +..... + +Bug fixes +~~~~~~~~~ + +.. note:: + The ``private_key_content`` extra of the Snowflake connection must now be base64-encoded, because the hooks base64-decode it before loading the key. To encode your private key, you can use the following Python snippet: + + .. code-block:: python + + import base64 + + with open("path/to/your/private_key.pem", "rb") as key_file: + encoded_key = base64.b64encode(key_file.read()).decode("utf-8") + print(encoded_key) + + 6.2.2 ..... diff --git a/providers/snowflake/docs/connections/snowflake.rst b/providers/snowflake/docs/connections/snowflake.rst index 903bc17c5da9e..9c8c6a4a87576 100644 --- a/providers/snowflake/docs/connections/snowflake.rst +++ b/providers/snowflake/docs/connections/snowflake.rst @@ -60,7 +60,15 @@ Extra (optional) * ``authenticator``: To connect using OAuth set this parameter ``oauth``. * ``refresh_token``: Specify refresh_token for OAuth connection. * ``private_key_file``: Specify the path to the private key file. - * ``private_key_content``: Specify the content of the private key file. + * ``private_key_content``: Specify the base64-encoded content of the private key file. You can use the following Python code to encode the private key: + + .. code-block:: python + + import base64 + + with open("path/to/private_key.pem", "rb") as key_file: + private_key_content = base64.b64encode(key_file.read()).decode("utf-8") + print(private_key_content) * ``session_parameters``: Specify `session level parameters `_. * ``insecure_mode``: Turn off OCSP certificate checks. For details, see: `How To: Turn Off OCSP Checking in Snowflake Client Drivers - Snowflake Community `_. * ``host``: Target Snowflake hostname to connect to (e.g., for local testing with LocalStack). 
diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py index 9796638250135..f72d7ed5b8c87 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py @@ -17,6 +17,7 @@ # under the License. from __future__ import annotations +import base64 import os from collections.abc import Iterable, Mapping from contextlib import closing, contextmanager @@ -289,7 +290,7 @@ def _get_conn_params(self) -> dict[str, str | None]: raise ValueError("The private_key_file size is too big. Please keep it less than 4 KB.") private_key_pem = Path(private_key_file_path).read_bytes() elif private_key_content: - private_key_pem = private_key_content.encode() + private_key_pem = base64.b64decode(private_key_content) if private_key_pem: passphrase = None diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py index 1fe3e6cd83cc5..c5ee2989cc9ae 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py @@ -16,6 +16,7 @@ # under the License. 
from __future__ import annotations +import base64 import uuid from datetime import timedelta from pathlib import Path @@ -120,7 +121,7 @@ def get_private_key(self) -> None: if private_key_file: private_key_pem = Path(private_key_file).read_bytes() elif private_key_content: - private_key_pem = private_key_content.encode() + private_key_pem = base64.b64decode(private_key_content) if private_key_pem: passphrase = None diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py index b682e2e851190..30fbd50289b6c 100644 --- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py +++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py @@ -17,6 +17,7 @@ # under the License. from __future__ import annotations +import base64 import json import sys from copy import deepcopy @@ -69,7 +70,7 @@ @pytest.fixture -def non_encrypted_temporary_private_key(tmp_path: Path) -> Path: +def unencrypted_temporary_private_key(tmp_path: Path) -> Path: key = rsa.generate_private_key(backend=default_backend(), public_exponent=65537, key_size=2048) private_key = key.private_bytes( serialization.Encoding.PEM, serialization.PrivateFormat.PKCS8, serialization.NoEncryption() @@ -79,6 +80,11 @@ def non_encrypted_temporary_private_key(tmp_path: Path) -> Path: return test_key_file +@pytest.fixture +def base64_encoded_unencrypted_private_key(unencrypted_temporary_private_key: Path) -> str: + return base64.b64encode(unencrypted_temporary_private_key.read_bytes()).decode("utf-8") + + @pytest.fixture def encrypted_temporary_private_key(tmp_path: Path) -> Path: key = rsa.generate_private_key(backend=default_backend(), public_exponent=65537, key_size=2048) @@ -92,6 +98,11 @@ def encrypted_temporary_private_key(tmp_path: Path) -> Path: return test_key_file +@pytest.fixture +def base64_encoded_encrypted_private_key(encrypted_temporary_private_key: Path) -> str: + return 
base64.b64encode(encrypted_temporary_private_key.read_bytes()).decode("utf-8") + + class TestPytestSnowflakeHook: @pytest.mark.parametrize( "connection_kwargs,expected_uri,expected_conn_params", @@ -358,7 +369,7 @@ def test_hook_should_support_prepare_basic_conn_params_and_uri( assert SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params == expected_conn_params def test_get_conn_params_should_support_private_auth_in_connection( - self, encrypted_temporary_private_key: Path + self, base64_encoded_encrypted_private_key: str ): connection_kwargs: Any = { **BASE_CONNECTION_KWARGS, @@ -369,7 +380,7 @@ def test_get_conn_params_should_support_private_auth_in_connection( "warehouse": "af_wh", "region": "af_region", "role": "af_role", - "private_key_content": str(encrypted_temporary_private_key.read_text()), + "private_key_content": base64_encoded_encrypted_private_key, }, } with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()): @@ -454,7 +465,7 @@ def test_get_conn_params_should_support_private_auth_with_encrypted_key( assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params def test_get_conn_params_should_support_private_auth_with_unencrypted_key( - self, non_encrypted_temporary_private_key + self, unencrypted_temporary_private_key ): connection_kwargs = { **BASE_CONNECTION_KWARGS, @@ -465,7 +476,7 @@ def test_get_conn_params_should_support_private_auth_with_unencrypted_key( "warehouse": "af_wh", "region": "af_region", "role": "af_role", - "private_key_file": str(non_encrypted_temporary_private_key), + "private_key_file": str(unencrypted_temporary_private_key), }, } with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()): @@ -620,10 +631,10 @@ def test_get_sqlalchemy_engine_should_support_session_parameters(self): ) assert mock_create_engine.return_value == conn - def test_get_sqlalchemy_engine_should_support_private_key_auth(self, 
non_encrypted_temporary_private_key): + def test_get_sqlalchemy_engine_should_support_private_key_auth(self, unencrypted_temporary_private_key): connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS) connection_kwargs["password"] = "" - connection_kwargs["extra"]["private_key_file"] = str(non_encrypted_temporary_private_key) + connection_kwargs["extra"]["private_key_file"] = str(unencrypted_temporary_private_key) with ( mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()), diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py index efb47f0c9d213..d90c38098f98b 100644 --- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py +++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import base64 import unittest import uuid from typing import TYPE_CHECKING, Any @@ -377,7 +378,7 @@ def test_get_oauth_token(self, mock_conn_param, requests_post, mock_auth): ) @pytest.fixture - def non_encrypted_temporary_private_key(self, tmp_path: Path) -> Path: + def unencrypted_temporary_private_key(self, tmp_path: Path) -> Path: """Encrypt the pem file from the path""" key = rsa.generate_private_key(backend=default_backend(), public_exponent=65537, key_size=2048) private_key = key.private_bytes( @@ -387,6 +388,10 @@ def non_encrypted_temporary_private_key(self, tmp_path: Path) -> Path: test_key_file.write_bytes(private_key) return test_key_file + @pytest.fixture + def base64_encoded_unencrypted_private_key(self, unencrypted_temporary_private_key: Path) -> str: + return base64.b64encode(unencrypted_temporary_private_key.read_bytes()).decode("utf-8") + @pytest.fixture def encrypted_temporary_private_key(self, tmp_path: Path) -> Path: """Encrypt private key from the temp path""" @@ -400,8 +405,12 @@ def 
encrypted_temporary_private_key(self, tmp_path: Path) -> Path: test_key_file.write_bytes(private_key) return test_key_file + @pytest.fixture + def base64_encoded_encrypted_private_key(self, encrypted_temporary_private_key: Path) -> str: + return base64.b64encode(encrypted_temporary_private_key.read_bytes()).decode("utf-8") + def test_get_private_key_should_support_private_auth_in_connection( - self, encrypted_temporary_private_key: Path + self, base64_encoded_encrypted_private_key: str ): """Test get_private_key function with private_key_content in connection""" connection_kwargs: Any = { @@ -413,7 +422,7 @@ def test_get_private_key_should_support_private_auth_in_connection( "warehouse": "af_wh", "region": "af_region", "role": "af_role", - "private_key_content": str(encrypted_temporary_private_key.read_text()), + "private_key_content": base64_encoded_encrypted_private_key, }, } with unittest.mock.patch.dict( @@ -423,7 +432,9 @@ def test_get_private_key_should_support_private_auth_in_connection( hook.get_private_key() assert hook.private_key is not None - def test_get_private_key_raise_exception(self, encrypted_temporary_private_key: Path): + def test_get_private_key_raise_exception( + self, encrypted_temporary_private_key: Path, base64_encoded_encrypted_private_key: str + ): """ Test get_private_key function with private_key_content and private_key_file in connection and raise airflow exception @@ -437,7 +448,7 @@ def test_get_private_key_raise_exception(self, encrypted_temporary_private_key: "warehouse": "af_wh", "region": "af_region", "role": "af_role", - "private_key_content": str(encrypted_temporary_private_key.read_text()), + "private_key_content": base64_encoded_encrypted_private_key, "private_key_file": str(encrypted_temporary_private_key), }, } @@ -479,7 +490,7 @@ def test_get_private_key_should_support_private_auth_with_encrypted_key( def test_get_private_key_should_support_private_auth_with_unencrypted_key( self, - non_encrypted_temporary_private_key, + 
unencrypted_temporary_private_key, ): connection_kwargs = { **BASE_CONNECTION_KWARGS, @@ -490,7 +501,7 @@ def test_get_private_key_should_support_private_auth_with_unencrypted_key( "warehouse": "af_wh", "region": "af_region", "role": "af_role", - "private_key_file": str(non_encrypted_temporary_private_key), + "private_key_file": str(unencrypted_temporary_private_key), }, } with unittest.mock.patch.dict(