diff --git a/CHANGELOG.md b/CHANGELOG.md index 3b54073..4f95873 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,12 +9,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- `private_key_path` and `private_key_passphrase` fields to `SnowflakeConnector` - [#59](https://github.com/PrefectHQ/prefect-snowflake/pull/59) + ### Changed - Do not start connection upon instantiating `SnowflakeConnector` until its methods are called - [#58](https://github.com/PrefectHQ/prefect-snowflake/pull/58) ### Deprecated +- `password` in favor of `private_key_passphrase` field in `SnowflakeConnector` - [#59](https://github.com/PrefectHQ/prefect-snowflake/pull/59) + ### Removed ### Fixed diff --git a/prefect_snowflake/credentials.py b/prefect_snowflake/credentials.py index 6919c74..eac4816 100644 --- a/prefect_snowflake/credentials.py +++ b/prefect_snowflake/credentials.py @@ -2,6 +2,7 @@ import re import warnings +from pathlib import Path from typing import Any, Dict, Optional, Union from cryptography.hazmat.backends import default_backend @@ -88,8 +89,15 @@ class SnowflakeCredentials(CredentialsBlock): private_key: Optional[SecretBytes] = Field( default=None, description="The PEM used to authenticate." ) + private_key_path: Optional[Path] = Field( + default=None, description="The path to the private key." + ) + private_key_passphrase: Optional[SecretStr] = Field( + default=None, description="The password to use for the private key." + ) authenticator: Literal[ "snowflake", + "snowflake_jwt", "externalbrowser", "okta_endpoint", "oauth", @@ -122,12 +130,27 @@ def _validate_auth_kwargs(cls, values): """ Ensure an authorization value has been provided by the user. """ - auth_params = ("password", "private_key", "authenticator", "token") + auth_params = ( + "password", + "private_key", + "private_key_path", + "authenticator", + "token", + ) if not any(values.get(param) for param in auth_params): auth_str = ", ".join(auth_params) raise ValueError( f"One of the authentication keys must be provided: {auth_str}\n" ) + elif "private_key" in values and "private_key_path" in values: + raise ValueError( + "Do not provide both private_key and private_key_path; select one." + ) + elif "password" in values and "private_key_passphrase" in values: + raise ValueError( + "Do not provide both password and private_key_passphrase; " + "specify private_key_passphrase only instead." + ) return values @root_validator(pre=True) @@ -154,7 +177,8 @@ def _validate_okta_kwargs(cls, values): # see https://github.com/PrefectHQ/prefect-snowflake/issues/44 if "okta_endpoint" in values.keys(): warnings.warn( - "Please specify `endpoint` instead of `okta_endpoint`.", + "Please specify `endpoint` instead of `okta_endpoint`; " + "`okta_endpoint` will be removed March 31, 2023.", DeprecationWarning, ) # remove okta endpoint from fields @@ -201,9 +225,22 @@ def resolve_private_key(self) -> Optional[bytes]: if private_key is None: return None + if self.private_key_passphrase is not None: + password = self._decode_secret(self.private_key_passphrase) + elif self.password is not None: + warnings.warn( + "Using the password field for private_key is deprecated " + "and will not work after March 31, 2023; please use " + "private_key_passphrase instead", + DeprecationWarning, + ) + password = self._decode_secret(self.password) + else: + password = None + return load_pem_private_key( data=private_key, - password=self._decode_secret(self.password), + password=password, backend=default_backend(), ).private_bytes( encoding=Encoding.DER, diff --git a/tests/conftest.py b/tests/conftest.py index 2539127..ddef380 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -71,6 +71,16 @@ def private_credentials_params(): } +@pytest.fixture() +def private_key_path_credentials_params(): + return { + "account": "account", + "user": "user", + "private_key_path": "path/to/private/key", + "private_key_passphrase": "letmein", + } + + @pytest.fixture() def private_connector_params(private_credentials_params): diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 5e9392c..6178b80 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -1,3 +1,4 @@ +from pathlib import Path from unittest.mock import MagicMock import pytest @@ -24,6 +25,21 @@ def test_snowflake_credentials_validate_auth_kwargs(credentials_params): SnowflakeCredentials(**credentials_params_missing) +def test_snowflake_credentials_validate_auth_kwargs_private_keys(credentials_params): + credentials_params["private_key"] = "key" + credentials_params["private_key_path"] = "keypath" + with pytest.raises(ValueError, match="Do not provide both private_key and private"): + SnowflakeCredentials(**credentials_params) + + +def test_snowflake_credentials_validate_auth_kwargs_private_key_password( + credentials_params, +): + credentials_params["private_key_passphrase"] = "key" + with pytest.raises(ValueError, match="Do not provide both password and private_"): + SnowflakeCredentials(**credentials_params) + + def test_snowflake_credentials_validate_token_kwargs(credentials_params): credentials_params_missing = credentials_params.copy() credentials_params_missing.pop("password") @@ -112,6 +128,22 @@ def test_snowflake_credentials_validate_private_key_password( assert credentials.resolve_private_key() is not None +def test_snowflake_credentials_validate_private_key_passphrase( + private_credentials_params, +): + private_credentials_params[ + "private_key_passphrase" + ] = private_credentials_params.pop("password") + credentials_params_missing = private_credentials_params.copy() + password = credentials_params_missing.pop("private_key_passphrase") + private_key = credentials_params_missing.pop("private_key") + assert password == "letmein" + assert isinstance(private_key, bytes) + # Test cert as string + credentials = SnowflakeCredentials(**private_credentials_params) + assert credentials.resolve_private_key() is not None + + def test_snowflake_credentials_validate_private_key_invalid(private_credentials_params): credentials_params_missing = private_credentials_params.copy() private_key = credentials_params_missing.pop("private_key") @@ -173,6 +205,21 @@ def test_snowflake_credentials_validate_private_key_is_pem_bytes( assert credentials.resolve_private_key() is not None +def test_snowflake_credentials_validate_private_key_path_init( + private_key_path_credentials_params, +): + snowflake_credentials = SnowflakeCredentials(**private_key_path_credentials_params) + actual_credentials_params = snowflake_credentials.dict() + for param in private_key_path_credentials_params: + actual = actual_credentials_params[param] + expected = private_key_path_credentials_params[param] + if isinstance(actual, (SecretStr, SecretBytes)): + actual = actual.get_secret_value() + elif isinstance(actual, Path): + actual = str(actual) + assert actual == expected + + def test_get_client(credentials_params, snowflake_connect_mock: MagicMock): snowflake_credentials = SnowflakeCredentials(**credentials_params) snowflake_credentials.get_client()