Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions providers/snowflake/docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,24 @@
Changelog
---------

main
.....

Bug fixes
~~~~~~~~~

.. note::
``private_key_content`` in Snowflake connection should now be base64 encoded. To encode your private key, you can use the following Python snippet:

.. code-block:: python

import base64

with open("path/to/your/private_key.pem", "rb") as key_file:
encoded_key = base64.b64encode(key_file.read()).decode("utf-8")
print(encoded_key)


6.2.2
.....

Expand Down
10 changes: 9 additions & 1 deletion providers/snowflake/docs/connections/snowflake.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,15 @@ Extra (optional)
* ``authenticator``: To connect using OAuth set this parameter ``oauth``.
* ``refresh_token``: Specify refresh_token for OAuth connection.
* ``private_key_file``: Specify the path to the private key file.
* ``private_key_content``: Specify the content of the private key file.
* ``private_key_content``: Specify the content of the private key file in base64 encoded format. You can use the following Python code to encode the private key:

.. code-block:: python

import base64

with open("path/to/private_key.pem", "rb") as key_file:
private_key_content = base64.b64encode(key_file.read()).decode("utf-8")
print(private_key_content)
* ``session_parameters``: Specify `session level parameters <https://docs.snowflake.com/en/user-guide/python-connector-example.html#setting-session-parameters>`_.
* ``insecure_mode``: Turn off OCSP certificate checks. For details, see: `How To: Turn Off OCSP Checking in Snowflake Client Drivers - Snowflake Community <https://community.snowflake.com/s/article/How-to-turn-off-OCSP-checking-in-Snowflake-client-drivers>`_.
* ``host``: Target Snowflake hostname to connect to (e.g., for local testing with LocalStack).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

import base64
import os
from collections.abc import Iterable, Mapping
from contextlib import closing, contextmanager
Expand Down Expand Up @@ -289,7 +290,7 @@ def _get_conn_params(self) -> dict[str, str | None]:
raise ValueError("The private_key_file size is too big. Please keep it less than 4 KB.")
private_key_pem = Path(private_key_file_path).read_bytes()
elif private_key_content:
private_key_pem = private_key_content.encode()
private_key_pem = base64.b64decode(private_key_content)

if private_key_pem:
passphrase = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import base64
import uuid
from datetime import timedelta
from pathlib import Path
Expand Down Expand Up @@ -120,7 +121,7 @@ def get_private_key(self) -> None:
if private_key_file:
private_key_pem = Path(private_key_file).read_bytes()
elif private_key_content:
private_key_pem = private_key_content.encode()
private_key_pem = base64.b64decode(private_key_content)

if private_key_pem:
passphrase = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

import base64
import json
import sys
from copy import deepcopy
Expand Down Expand Up @@ -69,7 +70,7 @@


@pytest.fixture
def non_encrypted_temporary_private_key(tmp_path: Path) -> Path:
def unencrypted_temporary_private_key(tmp_path: Path) -> Path:
key = rsa.generate_private_key(backend=default_backend(), public_exponent=65537, key_size=2048)
private_key = key.private_bytes(
serialization.Encoding.PEM, serialization.PrivateFormat.PKCS8, serialization.NoEncryption()
Expand All @@ -79,6 +80,11 @@ def non_encrypted_temporary_private_key(tmp_path: Path) -> Path:
return test_key_file


@pytest.fixture
def base64_encoded_unencrypted_private_key(self, unencrypted_temporary_private_key: Path) -> str:
return base64.b64encode(unencrypted_temporary_private_key.read_bytes()).decode("utf-8")


@pytest.fixture
def encrypted_temporary_private_key(tmp_path: Path) -> Path:
key = rsa.generate_private_key(backend=default_backend(), public_exponent=65537, key_size=2048)
Expand All @@ -92,6 +98,11 @@ def encrypted_temporary_private_key(tmp_path: Path) -> Path:
return test_key_file


@pytest.fixture
def base64_encoded_encrypted_private_key(encrypted_temporary_private_key: Path) -> str:
return base64.b64encode(encrypted_temporary_private_key.read_bytes()).decode("utf-8")


class TestPytestSnowflakeHook:
@pytest.mark.parametrize(
"connection_kwargs,expected_uri,expected_conn_params",
Expand Down Expand Up @@ -358,7 +369,7 @@ def test_hook_should_support_prepare_basic_conn_params_and_uri(
assert SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params == expected_conn_params

def test_get_conn_params_should_support_private_auth_in_connection(
self, encrypted_temporary_private_key: Path
self, base64_encoded_encrypted_private_key: Path
):
connection_kwargs: Any = {
**BASE_CONNECTION_KWARGS,
Expand All @@ -369,7 +380,7 @@ def test_get_conn_params_should_support_private_auth_in_connection(
"warehouse": "af_wh",
"region": "af_region",
"role": "af_role",
"private_key_content": str(encrypted_temporary_private_key.read_text()),
"private_key_content": base64_encoded_encrypted_private_key,
},
}
with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
Expand Down Expand Up @@ -454,7 +465,7 @@ def test_get_conn_params_should_support_private_auth_with_encrypted_key(
assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params

def test_get_conn_params_should_support_private_auth_with_unencrypted_key(
self, non_encrypted_temporary_private_key
self, unencrypted_temporary_private_key
):
connection_kwargs = {
**BASE_CONNECTION_KWARGS,
Expand All @@ -465,7 +476,7 @@ def test_get_conn_params_should_support_private_auth_with_unencrypted_key(
"warehouse": "af_wh",
"region": "af_region",
"role": "af_role",
"private_key_file": str(non_encrypted_temporary_private_key),
"private_key_file": str(unencrypted_temporary_private_key),
},
}
with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
Expand Down Expand Up @@ -620,10 +631,10 @@ def test_get_sqlalchemy_engine_should_support_session_parameters(self):
)
assert mock_create_engine.return_value == conn

def test_get_sqlalchemy_engine_should_support_private_key_auth(self, non_encrypted_temporary_private_key):
def test_get_sqlalchemy_engine_should_support_private_key_auth(self, unencrypted_temporary_private_key):
connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
connection_kwargs["password"] = ""
connection_kwargs["extra"]["private_key_file"] = str(non_encrypted_temporary_private_key)
connection_kwargs["extra"]["private_key_file"] = str(unencrypted_temporary_private_key)

with (
mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import base64
import unittest
import uuid
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -377,7 +378,7 @@ def test_get_oauth_token(self, mock_conn_param, requests_post, mock_auth):
)

@pytest.fixture
def non_encrypted_temporary_private_key(self, tmp_path: Path) -> Path:
def unencrypted_temporary_private_key(self, tmp_path: Path) -> Path:
"""Encrypt the pem file from the path"""
key = rsa.generate_private_key(backend=default_backend(), public_exponent=65537, key_size=2048)
private_key = key.private_bytes(
Expand All @@ -387,6 +388,10 @@ def non_encrypted_temporary_private_key(self, tmp_path: Path) -> Path:
test_key_file.write_bytes(private_key)
return test_key_file

@pytest.fixture
def base64_encoded_unencrypted_private_key(self, unencrypted_temporary_private_key: Path) -> str:
return base64.b64encode(unencrypted_temporary_private_key.read_bytes()).decode("utf-8")

@pytest.fixture
def encrypted_temporary_private_key(self, tmp_path: Path) -> Path:
"""Encrypt private key from the temp path"""
Expand All @@ -400,8 +405,12 @@ def encrypted_temporary_private_key(self, tmp_path: Path) -> Path:
test_key_file.write_bytes(private_key)
return test_key_file

@pytest.fixture
def base64_encoded_encrypted_private_key(self, encrypted_temporary_private_key: Path) -> str:
return base64.b64encode(encrypted_temporary_private_key.read_bytes()).decode("utf-8")

def test_get_private_key_should_support_private_auth_in_connection(
self, encrypted_temporary_private_key: Path
self, base64_encoded_encrypted_private_key: str
):
"""Test get_private_key function with private_key_content in connection"""
connection_kwargs: Any = {
Expand All @@ -413,7 +422,7 @@ def test_get_private_key_should_support_private_auth_in_connection(
"warehouse": "af_wh",
"region": "af_region",
"role": "af_role",
"private_key_content": str(encrypted_temporary_private_key.read_text()),
"private_key_content": base64_encoded_encrypted_private_key,
},
}
with unittest.mock.patch.dict(
Expand All @@ -423,7 +432,9 @@ def test_get_private_key_should_support_private_auth_in_connection(
hook.get_private_key()
assert hook.private_key is not None

def test_get_private_key_raise_exception(self, encrypted_temporary_private_key: Path):
def test_get_private_key_raise_exception(
self, encrypted_temporary_private_key: Path, base64_encoded_encrypted_private_key: str
):
"""
Test get_private_key function with private_key_content and private_key_file in connection
and raise airflow exception
Expand All @@ -437,7 +448,7 @@ def test_get_private_key_raise_exception(self, encrypted_temporary_private_key:
"warehouse": "af_wh",
"region": "af_region",
"role": "af_role",
"private_key_content": str(encrypted_temporary_private_key.read_text()),
"private_key_content": base64_encoded_encrypted_private_key,
"private_key_file": str(encrypted_temporary_private_key),
},
}
Expand Down Expand Up @@ -479,7 +490,7 @@ def test_get_private_key_should_support_private_auth_with_encrypted_key(

def test_get_private_key_should_support_private_auth_with_unencrypted_key(
self,
non_encrypted_temporary_private_key,
unencrypted_temporary_private_key,
):
connection_kwargs = {
**BASE_CONNECTION_KWARGS,
Expand All @@ -490,7 +501,7 @@ def test_get_private_key_should_support_private_auth_with_unencrypted_key(
"warehouse": "af_wh",
"region": "af_region",
"role": "af_role",
"private_key_file": str(non_encrypted_temporary_private_key),
"private_key_file": str(unencrypted_temporary_private_key),
},
}
with unittest.mock.patch.dict(
Expand Down