diff --git a/CHANGELOG.md b/CHANGELOG.md index ef4f9e5..6d66425 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,10 +9,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - `SnowflakeConnector` block - [#24](https://github.com/PrefectHQ/prefect-snowflake/pull/24) +- `okta_endpoint` field to `SnowflakeCredentials` - [#25](https://github.com/PrefectHQ/prefect-snowflake/pull/25) ### Changed - Moved the keywords, `database` and `warehouse`, from `credentials.SnowflakeCredentials` into `database.SnowflakeConnector` - [#24](https://github.com/PrefectHQ/prefect-snowflake/pull/24) - Moved the method `get_connection` from `credentials.SnowflakeCredentials` into `database.SnowflakeConnector` - [#24](https://github.com/PrefectHQ/prefect-snowflake/pull/24) +- `authenticator` field in `SnowflakeCredentials` to `Literal` type - [#25](https://github.com/PrefectHQ/prefect-snowflake/pull/25) ### Deprecated diff --git a/prefect_snowflake/credentials.py b/prefect_snowflake/credentials.py index aa8f4ca..819499e 100644 --- a/prefect_snowflake/credentials.py +++ b/prefect_snowflake/credentials.py @@ -2,6 +2,11 @@ from typing import Optional +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal + from prefect.blocks.core import Block from pydantic import Field, SecretBytes, SecretStr, root_validator @@ -22,6 +27,8 @@ class SnowflakeCredentials(Block): work in an environment where a browser is available. token (SecretStr): The OAuth or JWT Token to provide when authenticator is set to OAuth. + okta_endpoint (str): The Okta endpoint to use when authenticator is + set to `okta_endpoint`, e.g. `https://.okta.com`. role (str): The name of the default role to use. autocommit (bool): Whether to automatically commit. @@ -44,17 +51,26 @@ class SnowflakeCredentials(Block): private_key: Optional[SecretBytes] = Field( default=None, description="The PEM used to authenticate" ) - authenticator: Optional[str] = Field( + authenticator: Literal[ + "snowflake", + "externalbrowser", + "okta_endpoint", + "oauth", + "username_password_mfa", + ] = Field( # noqa + default="snowflake", + description=("The type of authenticator to use for initializing connection"), + ) + token: Optional[SecretStr] = Field( default=None, description=( - "The type of authenticator to use for initializing " - "connection (oauth, externalbrowser, etc)" + "The OAuth or JWT Token to provide when authenticator is set to `oauth`" ), ) - token: Optional[SecretStr] = Field( + endpoint: Optional[str] = Field( default=None, description=( - "The OAuth or JWT Token to provide when " "authenticator is set to oauth" + "The Okta endpoint to use when authenticator is set to `okta_endpoint`" ), ) role: Optional[str] = Field( @@ -76,3 +92,30 @@ def _validate_auth_kwargs(cls, values): f"One of the authentication keys must be provided: {auth_str}\n" ) return values + + @root_validator(pre=True) + def _validate_token_kwargs(cls, values): + """ + Ensure an authorization value has been provided by the user. + """ + authenticator = values.get("authenticator") + token = values.get("token") + if authenticator == "oauth" and not token: + raise ValueError( + "If authenticator is set to `oauth`, `token` must be provided" + ) + return values + + @root_validator(pre=True) + def _validate_okta_kwargs(cls, values): + """ + Ensure an authorization value has been provided by the user. + """ + authenticator = values.get("authenticator") + okta_endpoint = values.get("okta_endpoint") + if authenticator == "okta_endpoint" and not okta_endpoint: + raise ValueError( + "If authenticator is set to `okta_endpoint`, " + "`okta_endpoint` must be provided" + ) + return values diff --git a/prefect_snowflake/database.py b/prefect_snowflake/database.py index 5d46554..43735d5 100644 --- a/prefect_snowflake/database.py +++ b/prefect_snowflake/database.py @@ -68,6 +68,10 @@ def _get_connect_params(self) -> Dict[str, str]: if param in connect_params: connect_params[param] = connect_params[param].get_secret_value() + # set authenticator to the actual okta_endpoint + if connect_params.get("authenticator") == "okta_endpoint": + connect_params["authenticator"] = connect_params.pop("okta_endpoint") + return connect_params def get_connection(self) -> snowflake.connector.SnowflakeConnection: diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 709278d..70d5ea6 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -29,3 +29,27 @@ def test_snowflake_credentials_validate_auth_kwargs(credentials_params): credentials_params_missing.pop("password") with pytest.raises(ValueError, match="One of the authentication keys"): SnowflakeCredentials(**credentials_params_missing) + + +def test_snowflake_credentials_validate_token_kwargs(credentials_params): + credentials_params_missing = credentials_params.copy() + credentials_params_missing.pop("password") + credentials_params_missing["authenticator"] = "oauth" + with pytest.raises(ValueError, match="If authenticator is set to `oauth`"): + SnowflakeCredentials(**credentials_params_missing) + + # now test if passing both works + credentials_params_missing["token"] = "some_token" + assert SnowflakeCredentials(**credentials_params_missing) + + +def test_snowflake_credentials_validate_okta_endpoint_kwargs(credentials_params): + credentials_params_missing = credentials_params.copy() + credentials_params_missing.pop("password") + credentials_params_missing["authenticator"] = "okta_endpoint" + with pytest.raises(ValueError, match="If authenticator is set to `okta_endpoint`"): + SnowflakeCredentials(**credentials_params_missing) + + # now test if passing both works + credentials_params_missing["okta_endpoint"] = "https://account_name.okta.com" + assert SnowflakeCredentials(**credentials_params_missing) diff --git a/tests/test_database.py b/tests/test_database.py index 524fd1d..8b6d8a1 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -35,8 +35,20 @@ def test_snowflake_connector_password_is_secret_str(connector_params): def test_snowflake_connector_get_connect_params_get_secret_value(connector_params): snowflake_connector = SnowflakeConnector(**connector_params) - connector_params = snowflake_connector._get_connect_params() - assert connector_params["password"] == "password" + connect_params = snowflake_connector._get_connect_params() + assert connect_params["password"] == "password" + + +def test_snowflake_connector_get_connect_params_okta_endpoint(connector_params): + okta_endpoint = "https://account_name.okta.com" + connector_params_okta_endpoint = connector_params.copy() + connector_params_okta_endpoint["credentials"].password = None + connector_params_okta_endpoint["credentials"].authenticator = "okta_endpoint" + connector_params_okta_endpoint["credentials"].okta_endpoint = okta_endpoint + snowflake_connector = SnowflakeConnector(**connector_params_okta_endpoint) + connect_params = snowflake_connector._get_connect_params() + assert connect_params["authenticator"] == okta_endpoint + assert connect_params.get("okta_endpoint") is None class SnowflakeCursor: