Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Merge pull request #59 from PrefectHQ/private_key
Browse files Browse the repository at this point in the history
Add private key path
  • Loading branch information
ahuang11 authored Jan 4, 2023
2 parents dbae357 + 9576e95 commit 7cdf94d
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 3 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- `private_key_path` and `private_key_passphrase` fields to `SnowflakeConnector` - [#59](https://github.com/PrefectHQ/prefect-snowflake/pull/59)

### Changed

- Do not start connection upon instantiating `SnowflakeConnector` until its methods are called - [#58](https://github.com/PrefectHQ/prefect-snowflake/pull/58)

### Deprecated

- `password` in favor of `private_key_passphrase` field in `SnowflakeConnector` - [#59](https://github.com/PrefectHQ/prefect-snowflake/pull/59)

### Removed

### Fixed
Expand Down
43 changes: 40 additions & 3 deletions prefect_snowflake/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import re
import warnings
from pathlib import Path
from typing import Any, Dict, Optional, Union

from cryptography.hazmat.backends import default_backend
Expand Down Expand Up @@ -88,8 +89,15 @@ class SnowflakeCredentials(CredentialsBlock):
private_key: Optional[SecretBytes] = Field(
default=None, description="The PEM used to authenticate."
)
private_key_path: Optional[Path] = Field(
default=None, description="The path to the private key."
)
private_key_passphrase: Optional[SecretStr] = Field(
default=None, description="The password to use for the private key."
)
authenticator: Literal[
"snowflake",
"snowflake_jwt",
"externalbrowser",
"okta_endpoint",
"oauth",
Expand Down Expand Up @@ -122,12 +130,27 @@ def _validate_auth_kwargs(cls, values):
"""
Ensure an authorization value has been provided by the user.
"""
auth_params = ("password", "private_key", "authenticator", "token")
auth_params = (
"password",
"private_key",
"private_key_path",
"authenticator",
"token",
)
if not any(values.get(param) for param in auth_params):
auth_str = ", ".join(auth_params)
raise ValueError(
f"One of the authentication keys must be provided: {auth_str}\n"
)
elif "private_key" in values and "private_key_path" in values:
raise ValueError(
"Do not provide both private_key and private_key_path; select one."
)
elif "password" in values and "private_key_passphrase" in values:
raise ValueError(
"Do not provide both password and private_key_passphrase; "
"specify private_key_passphrase only instead."
)
return values

@root_validator(pre=True)
Expand All @@ -154,7 +177,8 @@ def _validate_okta_kwargs(cls, values):
# see https://github.com/PrefectHQ/prefect-snowflake/issues/44
if "okta_endpoint" in values.keys():
warnings.warn(
"Please specify `endpoint` instead of `okta_endpoint`.",
"Please specify `endpoint` instead of `okta_endpoint`; "
"`okta_endpoint` will be removed March 31, 2023.",
DeprecationWarning,
)
# remove okta endpoint from fields
Expand Down Expand Up @@ -201,9 +225,22 @@ def resolve_private_key(self) -> Optional[bytes]:
if private_key is None:
return None

if self.private_key_passphrase is not None:
password = self._decode_secret(self.private_key_passphrase)
elif self.password is not None:
warnings.warn(
"Using the password field for private_key is deprecated "
"and will not work after March 31, 2023; please use "
"private_key_passphrase instead",
DeprecationWarning,
)
password = self._decode_secret(self.password)
else:
password = None

return load_pem_private_key(
data=private_key,
password=self._decode_secret(self.password),
password=password,
backend=default_backend(),
).private_bytes(
encoding=Encoding.DER,
Expand Down
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,16 @@ def private_credentials_params():
}


@pytest.fixture()
def private_key_path_credentials_params():
return {
"account": "account",
"user": "user",
"private_key_path": "path/to/private/key",
"private_key_passphrase": "letmein",
}


@pytest.fixture()
def private_connector_params(private_credentials_params):

Expand Down
47 changes: 47 additions & 0 deletions tests/test_credentials.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pathlib import Path
from unittest.mock import MagicMock

import pytest
Expand All @@ -24,6 +25,21 @@ def test_snowflake_credentials_validate_auth_kwargs(credentials_params):
SnowflakeCredentials(**credentials_params_missing)


def test_snowflake_credentials_validate_auth_kwargs_private_keys(credentials_params):
credentials_params["private_key"] = "key"
credentials_params["private_key_path"] = "keypath"
with pytest.raises(ValueError, match="Do not provide both private_key and private"):
SnowflakeCredentials(**credentials_params)


def test_snowflake_credentials_validate_auth_kwargs_private_key_password(
credentials_params,
):
credentials_params["private_key_passphrase"] = "key"
with pytest.raises(ValueError, match="Do not provide both password and private_"):
SnowflakeCredentials(**credentials_params)


def test_snowflake_credentials_validate_token_kwargs(credentials_params):
credentials_params_missing = credentials_params.copy()
credentials_params_missing.pop("password")
Expand Down Expand Up @@ -112,6 +128,22 @@ def test_snowflake_credentials_validate_private_key_password(
assert credentials.resolve_private_key() is not None


def test_snowflake_credentials_validate_private_key_passphrase(
private_credentials_params,
):
private_credentials_params[
"private_key_passphrase"
] = private_credentials_params.pop("password")
credentials_params_missing = private_credentials_params.copy()
password = credentials_params_missing.pop("private_key_passphrase")
private_key = credentials_params_missing.pop("private_key")
assert password == "letmein"
assert isinstance(private_key, bytes)
# Test cert as string
credentials = SnowflakeCredentials(**private_credentials_params)
assert credentials.resolve_private_key() is not None


def test_snowflake_credentials_validate_private_key_invalid(private_credentials_params):
credentials_params_missing = private_credentials_params.copy()
private_key = credentials_params_missing.pop("private_key")
Expand Down Expand Up @@ -173,6 +205,21 @@ def test_snowflake_credentials_validate_private_key_is_pem_bytes(
assert credentials.resolve_private_key() is not None


def test_snowflake_credentials_validate_private_key_path_init(
private_key_path_credentials_params,
):
snowflake_credentials = SnowflakeCredentials(**private_key_path_credentials_params)
actual_credentials_params = snowflake_credentials.dict()
for param in private_key_path_credentials_params:
actual = actual_credentials_params[param]
expected = private_key_path_credentials_params[param]
if isinstance(actual, (SecretStr, SecretBytes)):
actual = actual.get_secret_value()
elif isinstance(actual, Path):
actual = str(actual)
assert actual == expected


def test_get_client(credentials_params, snowflake_connect_mock: MagicMock):
snowflake_credentials = SnowflakeCredentials(**credentials_params)
snowflake_credentials.get_client()
Expand Down

0 comments on commit 7cdf94d

Please sign in to comment.