Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Merge pull request #25 from PrefectHQ/literal_authenticator
Browse files Browse the repository at this point in the history
Makes authenticator a literal type
  • Loading branch information
ahuang11 authored Aug 12, 2022
2 parents 807b6cc + 19a687b commit 7d5bd4c
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 7 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- `SnowflakeConnector` block - [#24](https://github.com/PrefectHQ/prefect-snowflake/pull/24)
- `okta_endpoint` field to `SnowflakeCredentials` - [#25](https://github.com/PrefectHQ/prefect-snowflake/pull/25)

### Changed
- Moved the keywords, `database` and `warehouse`, from `credentials.SnowflakeCredentials` into `database.SnowflakeConnector` - [#24](https://github.com/PrefectHQ/prefect-snowflake/pull/24)
- Moved the method `get_connection` from `credentials.SnowflakeCredentials` into `database.SnowflakeConnector` - [#24](https://github.com/PrefectHQ/prefect-snowflake/pull/24)
- `authenticator` field in `SnowflakeCredentials` to `Literal` type - [#25](https://github.com/PrefectHQ/prefect-snowflake/pull/25)

### Deprecated

Expand Down
53 changes: 48 additions & 5 deletions prefect_snowflake/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

from typing import Optional

try:
from typing import Literal
except ImportError:
from typing_extensions import Literal

from prefect.blocks.core import Block
from pydantic import Field, SecretBytes, SecretStr, root_validator

Expand All @@ -22,6 +27,8 @@ class SnowflakeCredentials(Block):
work in an environment where a browser is available.
token (SecretStr): The OAuth or JWT Token to provide when
authenticator is set to OAuth.
okta_endpoint (str): The Okta endpoint to use when authenticator is
set to `okta_endpoint`, e.g. `https://<okta_account_name>.okta.com`.
role (str): The name of the default role to use.
autocommit (bool): Whether to automatically commit.
Expand All @@ -44,17 +51,26 @@ class SnowflakeCredentials(Block):
private_key: Optional[SecretBytes] = Field(
default=None, description="The PEM used to authenticate"
)
authenticator: Optional[str] = Field(
authenticator: Literal[
"snowflake",
"externalbrowser",
"okta_endpoint",
"oauth",
"username_password_mfa",
] = Field( # noqa
default="snowflake",
description=("The type of authenticator to use for initializing connection"),
)
token: Optional[SecretStr] = Field(
default=None,
description=(
"The type of authenticator to use for initializing "
"connection (oauth, externalbrowser, etc)"
"The OAuth or JWT Token to provide when authenticator is set to `oauth`"
),
)
token: Optional[SecretStr] = Field(
endpoint: Optional[str] = Field(
default=None,
description=(
"The OAuth or JWT Token to provide when " "authenticator is set to oauth"
"The Okta endpoint to use when authenticator is set to `okta_endpoint`"
),
)
role: Optional[str] = Field(
Expand All @@ -76,3 +92,30 @@ def _validate_auth_kwargs(cls, values):
f"One of the authentication keys must be provided: {auth_str}\n"
)
return values

@root_validator(pre=True)
def _validate_token_kwargs(cls, values):
"""
Ensure an authorization value has been provided by the user.
"""
authenticator = values.get("authenticator")
token = values.get("token")
if authenticator == "oauth" and not token:
raise ValueError(
"If authenticator is set to `oauth`, `token` must be provided"
)
return values

@root_validator(pre=True)
def _validate_okta_kwargs(cls, values):
"""
Ensure an authorization value has been provided by the user.
"""
authenticator = values.get("authenticator")
okta_endpoint = values.get("okta_endpoint")
if authenticator == "okta_endpoint" and not okta_endpoint:
raise ValueError(
"If authenticator is set to `okta_endpoint`, "
"`okta_endpoint` must be provided"
)
return values
4 changes: 4 additions & 0 deletions prefect_snowflake/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ def _get_connect_params(self) -> Dict[str, str]:
if param in connect_params:
connect_params[param] = connect_params[param].get_secret_value()

# set authenticator to the actual okta_endpoint
if connect_params.get("authenticator") == "okta_endpoint":
connect_params["authenticator"] = connect_params.pop("okta_endpoint")

return connect_params

def get_connection(self) -> snowflake.connector.SnowflakeConnection:
Expand Down
24 changes: 24 additions & 0 deletions tests/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,27 @@ def test_snowflake_credentials_validate_auth_kwargs(credentials_params):
credentials_params_missing.pop("password")
with pytest.raises(ValueError, match="One of the authentication keys"):
SnowflakeCredentials(**credentials_params_missing)


def test_snowflake_credentials_validate_token_kwargs(credentials_params):
credentials_params_missing = credentials_params.copy()
credentials_params_missing.pop("password")
credentials_params_missing["authenticator"] = "oauth"
with pytest.raises(ValueError, match="If authenticator is set to `oauth`"):
SnowflakeCredentials(**credentials_params_missing)

# now test if passing both works
credentials_params_missing["token"] = "some_token"
assert SnowflakeCredentials(**credentials_params_missing)


def test_snowflake_credentials_validate_okta_endpoint_kwargs(credentials_params):
credentials_params_missing = credentials_params.copy()
credentials_params_missing.pop("password")
credentials_params_missing["authenticator"] = "okta_endpoint"
with pytest.raises(ValueError, match="If authenticator is set to `okta_endpoint`"):
SnowflakeCredentials(**credentials_params_missing)

# now test if passing both works
credentials_params_missing["okta_endpoint"] = "https://account_name.okta.com"
assert SnowflakeCredentials(**credentials_params_missing)
16 changes: 14 additions & 2 deletions tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,20 @@ def test_snowflake_connector_password_is_secret_str(connector_params):

def test_snowflake_connector_get_connect_params_get_secret_value(connector_params):
snowflake_connector = SnowflakeConnector(**connector_params)
connector_params = snowflake_connector._get_connect_params()
assert connector_params["password"] == "password"
connect_params = snowflake_connector._get_connect_params()
assert connect_params["password"] == "password"


def test_snowflake_connector_get_connect_params_okta_endpoint(connector_params):
okta_endpoint = "https://account_name.okta.com"
connector_params_okta_endpoint = connector_params.copy()
connector_params_okta_endpoint["credentials"].password = None
connector_params_okta_endpoint["credentials"].authenticator = "okta_endpoint"
connector_params_okta_endpoint["credentials"].okta_endpoint = okta_endpoint
snowflake_connector = SnowflakeConnector(**connector_params_okta_endpoint)
connect_params = snowflake_connector._get_connect_params()
assert connect_params["authenticator"] == okta_endpoint
assert connect_params.get("okta_endpoint") is None


class SnowflakeCursor:
Expand Down

0 comments on commit 7d5bd4c

Please sign in to comment.