diff --git a/providers/snowflake/docs/connections/snowflake.rst b/providers/snowflake/docs/connections/snowflake.rst index 2d7076d120f55..6a314279efba6 100644 --- a/providers/snowflake/docs/connections/snowflake.rst +++ b/providers/snowflake/docs/connections/snowflake.rst @@ -59,7 +59,7 @@ Extra (optional) * ``authenticator``: To connect using OAuth set this parameter ``oauth``. * ``refresh_token``: Specify refresh_token for OAuth connection. * ``private_key_file``: Specify the path to the private key file. - * ``private_key_content``: Specify the content of the private key file. + * ``private_key_content``: Specify the content of the private key file. To input a PEM-formatted private key, replace all actual newlines with the literal ``\n`` in the key. Ensure the entire key, including the headers and footers, is on a single line with ``\n`` used to represent line breaks. * ``session_parameters``: Specify `session level parameters `_. * ``insecure_mode``: Turn off OCSP certificate checks. For details, see: `How To: Turn Off OCSP Checking in Snowflake Client Drivers - Snowflake Community `_. * ``host``: Target Snowflake hostname to connect to (e.g., for local testing with LocalStack). diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py index f997d20aecdef..6267d2c3c611b 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py @@ -262,7 +262,7 @@ def _get_conn_params(self) -> dict[str, str | None]: raise ValueError("The private_key_file size is too big. Please keep it less than 4 KB.") private_key_pem = Path(private_key_file_path).read_bytes() elif private_key_content: - private_key_pem = private_key_content.encode() + private_key_pem = private_key_content.replace("\\n", "\n").encode() if private_key_pem: passphrase = None diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py index 8770492d06ec1..cf90a9eea8af3 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py @@ -120,7 +120,9 @@ def get_private_key(self) -> None: elif private_key_file: private_key_pem = Path(private_key_file).read_bytes() elif private_key_content: - private_key_pem = private_key_content.encode() + # BS3PasswordFieldWidget treats input text literally. So, \n is converted to \\n. + # We need to convert it back to \n. + private_key_pem = private_key_content.replace("\\n", "\n").encode() if private_key_pem: passphrase = None diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py index b1a65b4293b66..8ccdc24d28fc8 100644 --- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py +++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py @@ -358,6 +358,25 @@ def test_get_conn_params_should_support_private_auth_in_connection( with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()): assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params + # \\n should be replaced with \n when invoking load_pem_private_key + connection_kwargs["extra"]["private_key_content"] = connection_kwargs["extra"][ + "private_key_content" + ].replace("\n", "\\n") + with ( + mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()), + mock.patch( + "cryptography.hazmat.primitives.serialization.load_pem_private_key" + ) as mock_load_pem_private_key, + ): + from cryptography.hazmat.backends import default_backend + + SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params + mock_load_pem_private_key.assert_called_once_with( + str(encrypted_temporary_private_key.read_text()).encode(), + password=_PASSWORD.encode(), + backend=default_backend(), + ) + @pytest.mark.parametrize("include_params", [True, False]) def test_hook_param_beats_extra(self, include_params): """When both hook params and extras are supplied, hook params should diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py index 1247e3e031820..e40f99c95d100 100644 --- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py +++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py @@ -423,6 +423,25 @@ def test_get_private_key_should_support_private_auth_in_connection( hook.get_private_key() assert hook.private_key is not None + # \\n should be replaced with \n when invoking load_pem_private_key + connection_kwargs["extra"]["private_key_content"] = connection_kwargs["extra"][ + "private_key_content" + ].replace("\n", "\\n") + with ( + mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()), + mock.patch( + "cryptography.hazmat.primitives.serialization.load_pem_private_key" + ) as mock_load_pem_private_key, + ): + from cryptography.hazmat.backends import default_backend + + SnowflakeSqlApiHook(snowflake_conn_id="test_conn")._get_conn_params + mock_load_pem_private_key.assert_called_once_with( + str(encrypted_temporary_private_key.read_text()).encode(), + password=_PASSWORD.encode(), + backend=default_backend(), + ) + def test_get_private_key_raise_exception(self, encrypted_temporary_private_key: Path): """ Test get_private_key function with private_key_content and private_key_file in connection