Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Retrieve secrets using get_secret_value #19

Merged
merged 11 commits into from
Jul 26, 2022
23 changes: 18 additions & 5 deletions prefect_snowflake/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

class SnowflakeCredentials(Block):
"""
Dataclass used to manage authentication with Snowflake.
Block used to manage authentication with Snowflake.

Args:
account: The snowflake account name.
Expand All @@ -28,7 +28,17 @@ class SnowflakeCredentials(Block):
schema: The name of the default schema to use.
role: The name of the default role to use.
autocommit: Whether to automatically commit.
""" # noqa

Example:
Load stored Snowflake credentials:
```python
from prefect_snowflake import SnowflakeCredentials
snowflake_credentials_block = SnowflakeCredentials.load("BLOCK_NAME")
```
""" # noqa E501

_block_type_name = "Snowflake Credentials"
_logo_url = "https://images.ctfassets.net/gm98wzqotmnx/2DxzAeTM9eHLDcRQx1FR34/f858a501cdff918d398b39365ec2150f/snowflake.png?h=250" # noqa

account: str
user: str
Expand All @@ -46,15 +56,18 @@ def block_initialization(self):
"""
Filter out unset values.
"""
password = self.password.get_secret_value() if self.password else None
private_key = self.private_key.get_secret_value() if self.private_key else None
token = self.token.get_secret_value() if self.token else None
connect_params = {
"account": self.account,
"user": self.user,
"password": self.password,
"password": password,
"database": self.database,
"warehouse": self.warehouse,
"private_key": self.private_key,
"private_key": private_key,
"authenticator": self.authenticator,
"token": self.token,
"token": token,
"schema": self.schema_,
"role": self.role,
"autocommit": self.autocommit,
Expand Down
5 changes: 2 additions & 3 deletions tests/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,10 @@ def snowflake_auth(monkeypatch):

def test_snowflake_credentials_post_init(connection_params):
snowflake_credentials = SnowflakeCredentials(**connection_params)
actual_connection_params = snowflake_credentials.connect_params
for param in connection_params:
actual = actual_connection_params[param]
expected = connection_params[param]
actual = getattr(snowflake_credentials, param)
if param == "password":
actual = actual.get_secret_value()
assert actual == expected

valid_params = dir(SnowflakeCredentials)
Expand Down