Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion providers/snowflake/docs/connections/snowflake.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ Extra (optional)
* ``authenticator``: To connect using OAuth set this parameter ``oauth``.
* ``refresh_token``: Specify refresh_token for OAuth connection.
* ``private_key_file``: Specify the path to the private key file.
* ``private_key_content``: Specify the content of the private key file.
* ``private_key_content``: Specify the content of the private key file. To input a PEM-formatted private key, replace all actual newlines with the literal ``\n`` in the key. Ensure the entire key, including the headers and footers, is on a single line with ``\n`` used to represent line breaks.
* ``session_parameters``: Specify `session level parameters <https://docs.snowflake.com/en/user-guide/python-connector-example.html#setting-session-parameters>`_.
* ``insecure_mode``: Turn off OCSP certificate checks. For details, see: `How To: Turn Off OCSP Checking in Snowflake Client Drivers - Snowflake Community <https://community.snowflake.com/s/article/How-to-turn-off-OCSP-checking-in-Snowflake-client-drivers>`_.
* ``host``: Target Snowflake hostname to connect to (e.g., for local testing with LocalStack).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def _get_conn_params(self) -> dict[str, str | None]:
raise ValueError("The private_key_file size is too big. Please keep it less than 4 KB.")
private_key_pem = Path(private_key_file_path).read_bytes()
elif private_key_content:
private_key_pem = private_key_content.encode()
private_key_pem = private_key_content.replace("\\n", "\n").encode()

if private_key_pem:
passphrase = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ def get_private_key(self) -> None:
elif private_key_file:
private_key_pem = Path(private_key_file).read_bytes()
elif private_key_content:
private_key_pem = private_key_content.encode()
# BS3PasswordFieldWidget treats input text literally. So, \n is converted to \\n.
# We need to convert it back to \n.
private_key_pem = private_key_content.replace("\\n", "\n").encode()

if private_key_pem:
passphrase = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,25 @@ def test_get_conn_params_should_support_private_auth_in_connection(
with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params

# \\n should be replaced with \n when invoking load_pem_private_key
connection_kwargs["extra"]["private_key_content"] = connection_kwargs["extra"][
"private_key_content"
].replace("\n", "\\n")
with (
mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()),
mock.patch(
"cryptography.hazmat.primitives.serialization.load_pem_private_key"
) as mock_load_pem_private_key,
):
from cryptography.hazmat.backends import default_backend

SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params
mock_load_pem_private_key.assert_called_once_with(
str(encrypted_temporary_private_key.read_text()).encode(),
password=_PASSWORD.encode(),
backend=default_backend(),
)

@pytest.mark.parametrize("include_params", [True, False])
def test_hook_param_beats_extra(self, include_params):
"""When both hook params and extras are supplied, hook params should
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,25 @@ def test_get_private_key_should_support_private_auth_in_connection(
hook.get_private_key()
assert hook.private_key is not None

# \\n should be replaced with \n when invoking load_pem_private_key
connection_kwargs["extra"]["private_key_content"] = connection_kwargs["extra"][
"private_key_content"
].replace("\n", "\\n")
with (
mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()),
mock.patch(
"cryptography.hazmat.primitives.serialization.load_pem_private_key"
) as mock_load_pem_private_key,
):
from cryptography.hazmat.backends import default_backend

SnowflakeSqlApiHook(snowflake_conn_id="test_conn")._get_conn_params
mock_load_pem_private_key.assert_called_once_with(
str(encrypted_temporary_private_key.read_text()).encode(),
password=_PASSWORD.encode(),
backend=default_backend(),
)

def test_get_private_key_raise_exception(self, encrypted_temporary_private_key: Path):
"""
Test get_private_key function with private_key_content and private_key_file in connection
Expand Down