From 1c9204b769410b125994f67265dac1346b5851a1 Mon Sep 17 00:00:00 2001 From: Tasko Olevski Date: Mon, 1 Jul 2024 09:57:17 +0200 Subject: [PATCH] fix: patching wrong environment variables when resuming (#1923) Fixes #1921. Reported by a user. The wrong environment variable was patched when the session was hibernated and the access tokens were expired. --- renku_notebooks/api/classes/k8s_client.py | 66 ++++++-------- renku_notebooks/util/kubernetes_.py | 32 +++---- tests/unit/test_k8s_client.py | 101 ++++++++++++++++++++-- 3 files changed, 131 insertions(+), 68 deletions(-) diff --git a/renku_notebooks/api/classes/k8s_client.py b/renku_notebooks/api/classes/k8s_client.py index bbdfc631b..01fdfea8d 100644 --- a/renku_notebooks/api/classes/k8s_client.py +++ b/renku_notebooks/api/classes/k8s_client.py @@ -246,18 +246,8 @@ def patch_image_pull_secret(self, server_name: str, gitlab_token: GitlabToken): patch, ) - def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens): - """Patch the Renku and Gitlab access tokens that are used in the session statefulset.""" - try: - sts = self._apps_v1.read_namespaced_stateful_set(name, self.namespace) - except ApiException as err: - if err.status == 404: - # NOTE: It can happen potentially that another request or something else - # deleted the session as this request was going on, in this case we ignore - # the missing statefulset - return - raise - + @staticmethod + def _get_statefulset_token_patches(sts: client.V1StatefulSet, renku_tokens: RenkuTokens) -> list[dict[str, str]]: containers: list[V1Container] = sts.spec.template.spec.containers init_containers: list[V1Container] = sts.spec.template.spec.init_containers @@ -266,15 +256,11 @@ def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens): (None, None), ) git_clone_container_index, git_clone_container = next( - ((i, c) for i, c in enumerate(init_containers) if c.name == "git-proxy"), + ((i, c) for i, c in enumerate(init_containers) if c.name == 
"git-clone"), (None, None), ) secrets_container_index, secrets_container = next( - ( - (i, c) - for i, c in enumerate(init_containers) - if c.name == "init-user-secrets" - ), + ((i, c) for i, c in enumerate(init_containers) if c.name == "init-user-secrets"), (None, None), ) @@ -294,16 +280,11 @@ def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens): else None ) secrets_renku_access_token_env = ( - find_env_var(secrets_container, "RENKU_ACCESS_TOKEN") - if secrets_container is not None - else None + find_env_var(secrets_container, "RENKU_ACCESS_TOKEN") if secrets_container is not None else None ) patches = list() - if ( - git_proxy_container_index is not None - and git_proxy_renku_access_token_env is not None - ): + if git_proxy_container_index is not None and git_proxy_renku_access_token_env is not None: patches.append( { "op": "replace", @@ -314,10 +295,7 @@ def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens): "value": renku_tokens.access_token, } ) - if ( - git_proxy_container_index is not None - and git_proxy_renku_refresh_token_env is not None - ): + if git_proxy_container_index is not None and git_proxy_renku_refresh_token_env is not None: patches.append( { "op": "replace", @@ -328,35 +306,45 @@ def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens): "value": renku_tokens.refresh_token, }, ) - if ( - git_clone_container_index is not None - and git_clone_renku_access_token_env is not None - ): + if git_clone_container_index is not None and git_clone_renku_access_token_env is not None: patches.append( { "op": "replace", "path": ( - f"/spec/template/spec/containers/{git_clone_container_index}" + f"/spec/template/spec/initContainers/{git_clone_container_index}" f"/env/{git_clone_renku_access_token_env[0]}/value" ), "value": renku_tokens.access_token, }, ) - if ( - secrets_container_index is not None - and secrets_renku_access_token_env is not None - ): + if secrets_container_index is not None and 
secrets_renku_access_token_env is not None: patches.append( { "op": "replace", "path": ( - f"/spec/template/spec/containers/{secrets_container_index}" + f"/spec/template/spec/initContainers/{secrets_container_index}" f"/env/{secrets_renku_access_token_env[0]}/value" ), "value": renku_tokens.access_token, }, ) + return patches + + def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens): + """Patch the Renku and Gitlab access tokens that are used in the session statefulset.""" + try: + sts = self._apps_v1.read_namespaced_stateful_set(name, self.namespace) + except ApiException as err: + if err.status == 404: + # NOTE: It can happen potentially that another request or something else + # deleted the session as this request was going on, in this case we ignore + # the missing statefulset + return + raise + + patches = self._get_statefulset_token_patches(sts, renku_tokens) + if not patches: return diff --git a/renku_notebooks/util/kubernetes_.py b/renku_notebooks/util/kubernetes_.py index 5b8076ed9..978297bf6 100644 --- a/renku_notebooks/util/kubernetes_.py +++ b/renku_notebooks/util/kubernetes_.py @@ -22,7 +22,7 @@ from typing import Any import escapism -from kubernetes.client import V1Container +from kubernetes.client import V1Container, V1EnvVarSource def filter_resources_by_annotations( @@ -37,10 +37,7 @@ def filter_resources_by_annotations( def filter_resource(resource): res = [] for annotation_name in annotations: - res.append( - resource["metadata"]["annotations"].get(annotation_name) - == annotations[annotation_name] - ) + res.append(resource["metadata"]["annotations"].get(annotation_name) == annotations[annotation_name]) if len(res) == 0: return True else: @@ -49,16 +46,12 @@ def filter_resource(resource): return list(filter(filter_resource, resources)) -def renku_1_make_server_name( - safe_username: str, namespace: str, project: str, branch: str, commit_sha: str -) -> str: +def renku_1_make_server_name(safe_username: str, namespace: str, 
project: str, branch: str, commit_sha: str) -> str: """Form a unique server name for Renku 1.0 sessions. This is used in naming all the k8s resources created by amalthea. """ - server_string_for_hashing = ( - f"{safe_username}-{namespace}-{project}-{branch}-{commit_sha}" - ) + server_string_for_hashing = f"{safe_username}-{namespace}-{project}-{branch}-{commit_sha}" server_hash = md5(server_string_for_hashing.encode()).hexdigest().lower() prefix = _make_server_name_prefix(safe_username) # NOTE: A K8s object name can only contain lowercase alphanumeric characters, hyphens, or dots. @@ -75,9 +68,7 @@ def renku_1_make_server_name( ) -def renku_2_make_server_name( - safe_username: str, project_id: str, launcher_id: str -) -> str: +def renku_2_make_server_name(safe_username: str, project_id: str, launcher_id: str) -> str: """Form a unique server name for Renku 2.0 sessions. This is used in naming all the k8s resources created by amalthea. @@ -95,7 +86,7 @@ def renku_2_make_server_name( return f"{prefix[:12]}-renku-2-{server_hash[:21]}" -def find_env_var(container: V1Container, env_name: str) -> tuple[int, str] | None: +def find_env_var(container: V1Container, env_name: str) -> tuple[int, str | V1EnvVarSource] | None: """Find the index and value of a specific environment variable by name from a Kubernetes container.""" env_var = next( filter( @@ -108,16 +99,15 @@ def find_env_var(container: V1Container, env_name: str) -> tuple[int, str] | Non return None ind = env_var[0] val = env_var[1].value + if val is None: + val = env_var[1].value_from return ind, val def _make_server_name_prefix(safe_username: str): safe_username_lowercase = safe_username.lower() prefix = "" - if ( - not safe_username_lowercase[0].isalpha() - or not safe_username_lowercase[0].isascii() - ): + if not safe_username_lowercase[0].isalpha() or not safe_username_lowercase[0].isascii(): # NOTE: Username starts with an invalid character. 
This has to be modified because a # k8s service object cannot start with anything other than a lowercase alphabet character. # NOTE: We do not have worry about collisions with already existing servers from older @@ -130,9 +120,7 @@ def _make_server_name_prefix(safe_username: str): return prefix -def find_container( - patches: list[dict[str, Any]], container_name: str -) -> dict[str, Any] | None: +def find_container(patches: list[dict[str, Any]], container_name: str) -> dict[str, Any] | None: """Find the json patch corresponding a given container.""" for patch_obj in patches: inner_patches = patch_obj.get("patch", []) diff --git a/tests/unit/test_k8s_client.py b/tests/unit/test_k8s_client.py index 6bfb91745..7d0d38479 100644 --- a/tests/unit/test_k8s_client.py +++ b/tests/unit/test_k8s_client.py @@ -1,8 +1,20 @@ import pytest - +from kubernetes.client import ( + V1Container, + V1EnvVar, + V1EnvVarSource, + V1LabelSelector, + V1PodSpec, + V1PodTemplateSpec, + V1StatefulSet, + V1StatefulSetSpec, +) + +from renku_notebooks.api.classes.auth import RenkuTokens from renku_notebooks.api.classes.k8s_client import JsServerCache, K8sClient, NamespacedK8sClient from renku_notebooks.errors.intermittent import JSCacheError from renku_notebooks.errors.programming import ProgrammingError +from renku_notebooks.util.kubernetes_ import find_env_var @pytest.fixture @@ -37,9 +49,7 @@ def test_list_cache_preference(mock_server_cache, mock_namespaced_client): renku_ns_client = mock_namespaced_client("renku") sessions_ns_client = mock_namespaced_client("renku-sessions") sample_server_manifest = {"metadata": {"labels": {"username": "username"}, "name": "server1"}} - sample_server_manifest_preferred = { - "metadata": {"labels": {"username": "username"}, "name": "preferred"} - } + sample_server_manifest_preferred = {"metadata": {"labels": {"username": "username"}, "name": "preferred"}} mock_server_cache.list_servers.return_value = [sample_server_manifest_preferred] 
renku_ns_client.list_servers.return_value = [] sessions_ns_client.list_servers.return_value = [sample_server_manifest] @@ -86,9 +96,7 @@ def test_get_two_results_raises_error(mock_server_cache, mock_namespaced_client) def test_get_cache_is_preferred(mock_server_cache, mock_namespaced_client): renku_ns_client = mock_namespaced_client("renku") sessions_ns_client = mock_namespaced_client("renku-sessions") - sample_server_manifest_cache = { - "metadata": {"labels": {"username": "username"}, "name": "server"} - } + sample_server_manifest_cache = {"metadata": {"labels": {"username": "username"}, "name": "server"}} sample_server_manifest_non_cache = { "metadata": { "labels": {"username": "username", "not_preferred": True}, @@ -112,3 +120,82 @@ def test_get_server_no_match(mock_server_cache, mock_namespaced_client): client = K8sClient(mock_server_cache, renku_ns_client, "username", sessions_ns_client) server = client.get_server("server", "username") assert server is None + + +def test_find_env_var(): + container = V1Container( + name="test", env=[V1EnvVar(name="key1", value="val1"), V1EnvVar(name="key2", value_from=V1EnvVarSource())] + ) + assert find_env_var(container, "key1") == (0, "val1") + assert find_env_var(container, "key2") == (1, V1EnvVarSource()) + assert find_env_var(container, "missing") is None + + +def test_patch_statefulset_tokens(): + git_clone_access_env = "GIT_CLONE_USER__RENKU_TOKEN" + git_proxy_access_env = "GIT_PROXY_RENKU_ACCESS_TOKEN" + git_proxy_refresh_env = "GIT_PROXY_RENKU_REFRESH_TOKEN" + secrets_access_env = "RENKU_ACCESS_TOKEN" + git_clone = V1Container( + name="git-clone", + env=[ + V1EnvVar(name="test", value="value"), + V1EnvVar(git_clone_access_env, "old_value"), + V1EnvVar(name="test-from-source", value_from=V1EnvVarSource()), + ], + ) + git_proxy = V1Container( + name="git-proxy", + env=[ + V1EnvVar(name="test", value="value"), + V1EnvVar(name="test-from-source", value_from=V1EnvVarSource()), + V1EnvVar(git_proxy_refresh_env, 
"old_value"), + V1EnvVar(git_proxy_access_env, "old_value"), + ], + ) + secrets = V1Container( + name="init-user-secrets", + env=[ + V1EnvVar(secrets_access_env, "old_value"), + V1EnvVar(name="test", value="value"), + V1EnvVar(name="test-from-source", value_from=V1EnvVarSource()), + ], + ) + random1 = V1Container(name="random1") + random2 = V1Container( + name="random2", + env=[ + V1EnvVar(name="test", value="value"), + V1EnvVar(name="test-from-source", value_from=V1EnvVarSource()), + ], + ) + + new_renku_tokens = RenkuTokens(access_token="new_renku_access_token", refresh_token="new_renku_refresh_token") + + sts = V1StatefulSet( + spec=V1StatefulSetSpec( + service_name="test", + selector=V1LabelSelector(), + template=V1PodTemplateSpec( + spec=V1PodSpec( + containers=[git_proxy, random1, random2], init_containers=[git_clone, random1, secrets, random2] + ) + ), + ) + ) + patches = NamespacedK8sClient._get_statefulset_token_patches(sts, new_renku_tokens) + + # Order of patches should be git proxy access, git proxy refresh, git clone, secrets + assert len(patches) == 4 + # Git proxy access token + assert patches[0]["path"] == "/spec/template/spec/containers/0/env/3/value" + assert patches[0]["value"] == new_renku_tokens.access_token + # Git proxy refresh token + assert patches[1]["path"] == "/spec/template/spec/containers/0/env/2/value" + assert patches[1]["value"] == new_renku_tokens.refresh_token + # Git clone + assert patches[2]["path"] == "/spec/template/spec/initContainers/0/env/1/value" + assert patches[2]["value"] == new_renku_tokens.access_token + # Secrets init + assert patches[3]["path"] == "/spec/template/spec/initContainers/2/env/0/value" + assert patches[3]["value"] == new_renku_tokens.access_token