Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support GCP secrets #1571

Merged
merged 4 commits into from
Mar 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions flytekit/core/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,13 +329,13 @@ def __getattr__(self, item: str) -> _GroupSecrets:
"""
return self._GroupSecrets(item, self)

def get(self, group: str, key: str) -> str:
def get(self, group: str, key: Optional[str] = None, group_version: Optional[str] = None) -> str:
"""
Retrieves a secret using the resolution order -> Env followed by file. If not found raises a ValueError
"""
self.check_group_key(group, key)
env_var = self.get_secrets_env_var(group, key)
fpath = self.get_secrets_file(group, key)
self.check_group_key(group)
env_var = self.get_secrets_env_var(group, key, group_version)
fpath = self.get_secrets_file(group, key, group_version)
v = os.environ.get(env_var)
if v is not None:
return v
Expand All @@ -346,26 +346,27 @@ def get(self, group: str, key: str) -> str:
f"Unable to find secret for key {key} in group {group} " f"in Env Var:{env_var} and FilePath: {fpath}"
)

def get_secrets_env_var(self, group: str, key: str) -> str:
def get_secrets_env_var(self, group: str, key: Optional[str] = None, group_version: Optional[str] = None) -> str:
"""
Returns a string that matches the ENV Variable to look for the secrets
"""
self.check_group_key(group, key)
return f"{self._env_prefix}{group.upper()}_{key.upper()}"
self.check_group_key(group)
l = [k.upper() for k in filter(None, (group, group_version, key))]
return f"{self._env_prefix}{'_'.join(l)}"

def get_secrets_file(self, group: str, key: str) -> str:
def get_secrets_file(self, group: str, key: Optional[str] = None, group_version: Optional[str] = None) -> str:
"""
Returns a path that matches the file to look for the secrets
"""
self.check_group_key(group, key)
return os.path.join(self._base_dir, group.lower(), f"{self._file_prefix}{key.lower()}")
self.check_group_key(group)
l = [k.lower() for k in filter(None, (group, group_version, key))]
l[-1] = f"{self._file_prefix}{l[-1]}"
return os.path.join(self._base_dir, *l)

@staticmethod
def check_group_key(group: str, key: str):
def check_group_key(group: str):
if group is None or group == "":
raise ValueError("secrets group is a mandatory field.")
if key is None or key == "":
raise ValueError("secrets key is a mandatory field.")


@dataclass(frozen=True)
Expand Down
6 changes: 2 additions & 4 deletions flytekit/models/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,13 @@ class MountType(Enum):
"""

group: str
key: str
key: Optional[str] = None
group_version: Optional[str] = None
mount_requirement: MountType = MountType.ANY

def __post_init__(self):
if self.group is None:
raise ValueError("Group is a required parameter")
if self.key is None:
raise ValueError("Key is also a required parameter")

def to_flyte_idl(self) -> _sec.Secret:
return _sec.Secret(
Expand All @@ -59,7 +57,7 @@ def from_flyte_idl(cls, pb2_object: _sec.Secret) -> "Secret":
return cls(
group=pb2_object.group,
group_version=pb2_object.group_version if pb2_object.group_version else None,
key=pb2_object.key,
key=pb2_object.key if pb2_object.key else None,
mount_requirement=Secret.MountType(pb2_object.mount_requirement),
)

Expand Down
15 changes: 9 additions & 6 deletions tests/flytekit/unit/core/test_context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,18 +115,17 @@ def test_secrets_manager_default():

def test_secrets_manager_get_envvar():
sec = SecretsManager()
with pytest.raises(ValueError):
sec.get_secrets_env_var("test", "")
with pytest.raises(ValueError):
sec.get_secrets_env_var("", "x")
cfg = SecretsConfig.auto()
assert sec.get_secrets_env_var("group", "test") == f"{cfg.env_prefix}GROUP_TEST"
assert sec.get_secrets_env_var("group", "test", "v1") == f"{cfg.env_prefix}GROUP_V1_TEST"
assert sec.get_secrets_env_var("group", group_version="v1") == f"{cfg.env_prefix}GROUP_V1"
assert sec.get_secrets_env_var("group") == f"{cfg.env_prefix}GROUP"


def test_secrets_manager_get_file():
sec = SecretsManager()
with pytest.raises(ValueError):
sec.get_secrets_file("test", "")
with pytest.raises(ValueError):
sec.get_secrets_file("", "x")
cfg = SecretsConfig.auto()
Expand All @@ -135,6 +134,12 @@ def test_secrets_manager_get_file():
"group",
f"{cfg.file_prefix}test",
)
assert sec.get_secrets_file("group", "test", "v1") == os.path.join(
cfg.default_dir,
"group",
"v1",
f"{cfg.file_prefix}test",
)


def test_secrets_manager_file(tmpdir: py.path.local):
Expand All @@ -145,8 +150,6 @@ def test_secrets_manager_file(tmpdir: py.path.local):
with open(f, "w+") as w:
w.write("my-password")

with pytest.raises(ValueError):
sec.get("test", "")
with pytest.raises(ValueError):
sec.get("", "x")
# Group dir not exists
Expand Down
13 changes: 13 additions & 0 deletions tests/flytekit/unit/models/core/test_security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from flytekit.models.security import Secret


def test_secret():
obj = Secret("grp", "key")
obj2 = Secret.from_flyte_idl(obj.to_flyte_idl())
assert obj2.key == "key"
assert obj2.group_version is None

obj = Secret("grp", group_version="v1")
obj2 = Secret.from_flyte_idl(obj.to_flyte_idl())
assert obj2.key is None
assert obj2.group_version == "v1"