From bad62daaf064d10a6316741b564b1d52941a9bd8 Mon Sep 17 00:00:00 2001 From: David Pollack Date: Tue, 18 Apr 2023 09:46:18 +0200 Subject: [PATCH 1/6] Revert "Revert "Add kubernetes labels (#1236)" (#1359)" This reverts commit e68d63f96181d5916332a1295feb3a57cad7d122. --- metaflow/metaflow_config.py | 2 + metaflow/plugins/argo/argo_workflows.py | 6 +- metaflow/plugins/kubernetes/kubernetes.py | 6 +- metaflow/plugins/kubernetes/kubernetes_cli.py | 17 ++- .../kubernetes/kubernetes_decorator.py | 110 +++++++++++++++--- test/unit/test_kubernetes_decorator.py | 94 +++++++++++++++ 6 files changed, 216 insertions(+), 19 deletions(-) create mode 100644 test/unit/test_kubernetes_decorator.py diff --git a/metaflow/metaflow_config.py b/metaflow/metaflow_config.py index 7c1a6d1dfa8..7e1d8473e40 100644 --- a/metaflow/metaflow_config.py +++ b/metaflow/metaflow_config.py @@ -269,6 +269,8 @@ KUBERNETES_NODE_SELECTOR = from_conf("KUBERNETES_NODE_SELECTOR", "") KUBERNETES_TOLERATIONS = from_conf("KUBERNETES_TOLERATIONS", "") KUBERNETES_SECRETS = from_conf("KUBERNETES_SECRETS", "") +# Default labels for kubernetes pods +KUBERNETES_LABELS = from_conf("KUBERNETES_LABELS", "") # Default GPU vendor to use by K8S jobs created by Metaflow (supports nvidia, amd) KUBERNETES_GPU_VENDOR = from_conf("KUBERNETES_GPU_VENDOR", "nvidia") # Default container image for K8S diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index 72e9ba5ec47..0c0f6d4c084 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -887,8 +887,7 @@ def _container_templates(self): .retry_strategy( times=total_retries, minutes_between_retries=minutes_between_retries, - ) - .metadata( + ).metadata( ObjectMeta().annotation("metaflow/step_name", node.name) # Unfortunately, we can't set the task_id since it is generated # inside the pod. However, it can be inferred from the annotation @@ -896,11 +895,14 @@ def _container_templates(self): # the field 'task-id' in 'parameters' # .annotation("metaflow/task_id", ...) .annotation("metaflow/attempt", retry_count) + # Set labels + .labels(resources.get("labels")) ) # Set emptyDir volume for state management .empty_dir_volume("out") # Set node selectors .node_selectors(resources.get("node_selector")) + # Set tolerations .tolerations(resources.get("tolerations")) # Set container .container( diff --git a/metaflow/plugins/kubernetes/kubernetes.py b/metaflow/plugins/kubernetes/kubernetes.py index d22ef004e6d..afc48cfe663 100644 --- a/metaflow/plugins/kubernetes/kubernetes.py +++ b/metaflow/plugins/kubernetes/kubernetes.py @@ -9,6 +9,8 @@ from metaflow.metaflow_config import ( SERVICE_HEADERS, SERVICE_INTERNAL_URL, + CARD_AZUREROOT, + CARD_GSROOT, CARD_S3ROOT, DATASTORE_SYSROOT_S3, DATATOOLS_S3ROOT, @@ -29,8 +31,8 @@ BASH_SAVE_LOGS, bash_capture_logs, export_mflog_env_vars, - tail_logs, get_log_tailer, + tail_logs, ) from .kubernetes_client import KubernetesClient @@ -152,6 +154,7 @@ def create_job( run_time_limit=None, env=None, tolerations=None, + labels=None, ): if env is None: @@ -185,6 +188,7 @@ def create_job( retries=0, step_name=step_name, tolerations=tolerations, + labels=labels, ) .environment_variable("METAFLOW_CODE_SHA", code_package_sha) .environment_variable("METAFLOW_CODE_URL", code_package_url) diff --git a/metaflow/plugins/kubernetes/kubernetes_cli.py b/metaflow/plugins/kubernetes/kubernetes_cli.py index 90a7d32c579..e86cd25db2a 100644 --- a/metaflow/plugins/kubernetes/kubernetes_cli.py +++ b/metaflow/plugins/kubernetes/kubernetes_cli.py @@ -3,7 +3,7 @@ import time import traceback -from metaflow import util, JSONTypeClass +from metaflow import JSONTypeClass, util from metaflow._vendor import click from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY, CommandException from metaflow.metadata.util import sync_local_metadata_from_datastore @@ -91,6 +91,12 @@ def kubernetes(): type=JSONTypeClass(), multiple=False, ) +@click.option( + "--labels", + multiple=True, + default=None, + help="Labels for Kubernetes pod.", +) @click.pass_context def step( ctx, @@ -110,6 +116,7 @@ def step( gpu_vendor=None, run_time_limit=None, tolerations=None, + labels=None, **kwargs ): def echo(msg, stream="stderr", job_id=None, **kwargs): @@ -175,7 +182,12 @@ def echo(msg, stream="stderr", job_id=None, **kwargs): stderr_location = ds.get_log_location(TASK_LOG_SOURCE, "stderr") # `node_selector` is a tuple of strings, convert it to a dictionary - node_selector = KubernetesDecorator.parse_node_selector(node_selector) + node_selector = KubernetesDecorator.parse_kube_keyvalue_list(node_selector) + + # `labels` is a tuple of strings or a tuple with a single comma separated string + # convert it to a dict + labels = KubernetesDecorator.parse_kube_keyvalue_list(labels, False) + KubernetesDecorator.validate_kube_labels(labels) def _sync_metadata(): if ctx.obj.metadata.TYPE == "local": @@ -218,6 +230,7 @@ def _sync_metadata(): run_time_limit=run_time_limit, env=env, tolerations=tolerations, + labels=labels, ) except Exception as e: traceback.print_exc(chain=False) diff --git a/metaflow/plugins/kubernetes/kubernetes_decorator.py b/metaflow/plugins/kubernetes/kubernetes_decorator.py index ef508853219..e182423aa7f 100644 --- a/metaflow/plugins/kubernetes/kubernetes_decorator.py +++ b/metaflow/plugins/kubernetes/kubernetes_decorator.py @@ -1,7 +1,10 @@ +import hashlib import json import os import platform +import re import sys +from typing import Dict, List, Optional, Union from metaflow.decorators import StepDecorator from metaflow.exception import MetaflowException @@ -12,11 +15,12 @@ KUBERNETES_CONTAINER_IMAGE, KUBERNETES_CONTAINER_REGISTRY, KUBERNETES_GPU_VENDOR, + KUBERNETES_LABELS, KUBERNETES_NAMESPACE, KUBERNETES_NODE_SELECTOR, KUBERNETES_TOLERATIONS, - KUBERNETES_SERVICE_ACCOUNT, KUBERNETES_SECRETS, + KUBERNETES_SERVICE_ACCOUNT, KUBERNETES_FETCH_EC2_METADATA, ) from metaflow.plugins.resources_decorator import ResourcesDecorator @@ -65,6 +69,8 @@ class KubernetesDecorator(StepDecorator): in Metaflow configuration. tolerations : List[str], default: METAFLOW_KUBERNETES_TOLERATIONS Kubernetes tolerations to use when launching pod in Kubernetes. + labels : Dict[str, str], default: METAFLOW_KUBERNETES_LABELS + Kubernetes labels to use when launching pod in Kubernetes. """ name = "kubernetes" @@ -76,6 +82,7 @@ class KubernetesDecorator(StepDecorator): "service_account": None, "secrets": None, # e.g., mysecret "node_selector": None, # e.g., kubernetes.io/os=linux + "labels": None, # e.g., my_label=my_value "namespace": None, "gpu": None, # value of 0 implies that the scheduled node should not have GPUs "gpu_vendor": None, @@ -99,9 +106,17 @@ def __init__(self, attributes=None, statically_defined=False): self.attributes["node_selector"] = KUBERNETES_NODE_SELECTOR if not self.attributes["tolerations"] and KUBERNETES_TOLERATIONS: self.attributes["tolerations"] = json.loads(KUBERNETES_TOLERATIONS) + if not self.attributes["labels"] and KUBERNETES_LABELS: + self.attributes["labels"] = KUBERNETES_LABELS + + if isinstance(self.attributes["labels"], str): + self.attributes["labels"] = self.parse_kube_keyvalue_list( + self.attributes["labels"].split(","), False + ) + self.validate_kube_labels(self.attributes["labels"]) if isinstance(self.attributes["node_selector"], str): - self.attributes["node_selector"] = self.parse_node_selector( + self.attributes["node_selector"] = self.parse_kube_keyvalue_list( self.attributes["node_selector"].split(",") ) @@ -280,10 +295,11 @@ def runtime_step_cli( for k, v in self.attributes.items(): if k == "namespace": cli_args.command_options["k8s_namespace"] = v - elif k == "node_selector" and v: - cli_args.command_options[k] = ",".join( - ["=".join([key, str(val)]) for key, val in v.items()] - ) + elif k in {"node_selector", "labels"} and v: + cli_args.command_options[k] = [ + "=".join([key, str(val)]) if val else key + for key, val in v.items() + ] elif k == "tolerations": cli_args.command_options[k] = json.dumps(v) else: @@ -391,14 +407,80 @@ def _save_package_once(cls, flow_datastore, package): [package.blob], len_hint=1 )[0] + @classmethod + def _parse_decorator_spec(cls, deco_spec: str): + if not deco_spec: + return cls() + + valid_options = "|".join(cls.defaults.keys()) + deco_spec_parts = [] + for part in re.split(""",(?=[\s\w]+[{}]=)""".format(valid_options), deco_spec): + name, val = part.split("=", 1) + if name in {"labels", "node_selector"}: + try: + tmp_vals = json.loads(val.strip().replace('\\"', '"')) + for val_i in tmp_vals.values(): + if not (val_i is None or isinstance(val_i, str)): + raise KubernetesException( + "All values must be string or null." + ) + except json.JSONDecodeError: + if val.startswith("{"): + raise KubernetesException( + "Malform json detected in %s" % str(val) + ) + both = name == "node_selector" + val = json.dumps( + cls.parse_kube_keyvalue_list(val.split(","), both), + separators=(",", ":"), + ) + deco_spec_parts.append("=".join([name, val])) + deco_spec_parsed = ",".join(deco_spec_parts) + return super()._parse_decorator_spec(deco_spec_parsed) + @staticmethod - def parse_node_selector(node_selector: list): + def parse_kube_keyvalue_list(items: List[str], requires_both: bool = True): try: - return { - str(k.split("=", 1)[0]): str(k.split("=", 1)[1]) - for k in node_selector or [] - } + ret = {} + for item_str in items: + item = item_str.split("=", 1) + if requires_both: + item[1] # raise IndexError + if str(item[0]) in ret: + raise KubernetesException("Duplicate key found: %s" % str(item[0])) + ret[str(item[0])] = str(item[1]) if len(item) > 1 else None + return ret + except KubernetesException as e: + raise e except (AttributeError, IndexError): - raise KubernetesException( - "Unable to parse node_selector: %s" % node_selector - ) + raise KubernetesException("Unable to parse kubernetes list: %s" % items) + + @staticmethod + def validate_kube_labels( + labels: Optional[Dict[str, Optional[str]]], + ) -> bool: + """Validate label values. + + This validates the kubernetes label values. It does not validate the keys. + Ideally, keys should be static and also the validation rules for keys are + more complex than those for values. For full validation rules, see: + + https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set + """ + + def validate_label(s: Optional[str]): + regex_match = r"^(([A-Za-z0-9][-A-Za-z0-9_.]{0,61})?[A-Za-z0-9])?$" + if not s: + # allow empty label + return True + if not re.search(regex_match, s): + raise KubernetesException( + 'Invalid value: "%s"\n' + "A valid label must be an empty string or one that\n" + " - Consist of alphanumeric, '-', '_' or '.' characters\n" + " - Begins and ends with an alphanumeric character\n" + " - Is at most 63 characters" % s + ) + return True + + return all([validate_label(v) for v in labels.values()]) if labels else True diff --git a/test/unit/test_kubernetes_decorator.py b/test/unit/test_kubernetes_decorator.py new file mode 100644 index 00000000000..f779faceb9f --- /dev/null +++ b/test/unit/test_kubernetes_decorator.py @@ -0,0 +1,94 @@ +import pytest + +from metaflow.plugins.kubernetes.kubernetes import KubernetesException +from metaflow.plugins.kubernetes.kubernetes_decorator import KubernetesDecorator + + +@pytest.mark.parametrize( + "labels", + [ + None, + {"label": "value"}, + {"label1": "val1", "label2": "val2"}, + {"label1": "val1", "label2": None}, + {"label": "a"}, + {"label": ""}, + { + "label": ( + "1234567890" + "1234567890" + "1234567890" + "1234567890" + "1234567890" + "1234567890" + "123" + ) + }, + { + "label": ( + "1234567890" + "1234567890" + "1234-_.890" + "1234567890" + "1234567890" + "1234567890" + "123" + ) + }, + ], +) +def test_kubernetes_decorator_validate_kube_labels(labels): + assert KubernetesDecorator.validate_kube_labels(labels) + + +@pytest.mark.parametrize( + "labels", + [ + {"label": "a-"}, + {"label": ".a"}, + {"label": "test()"}, + { + "label": ( + "1234567890" + "1234567890" + "1234567890" + "1234567890" + "1234567890" + "1234567890" + "1234" + ) + }, + {"label": "(){}??"}, + {"valid": "test", "invalid": "bißchen"}, + ], +) +def test_kubernetes_decorator_validate_kube_labels_fail(labels): + """Fail if label contains invalid characters or is too long""" + with pytest.raises(KubernetesException): + KubernetesDecorator.validate_kube_labels(labels) + + +@pytest.mark.parametrize( + "items,requires_both,expected", + [ + (["key=value"], True, {"key": "value"}), + (["key=value"], False, {"key": "value"}), + (["key"], False, {"key": None}), + (["key=value", "key2=value2"], True, {"key": "value", "key2": "value2"}), + ], +) +def test_kubernetes_parse_keyvalue_list(items, requires_both, expected): + ret = KubernetesDecorator.parse_kube_keyvalue_list(items, requires_both) + assert ret == expected + + +@pytest.mark.parametrize( + "items,requires_both", + [ + (["key=value", "key=value2"], True), + (["key"], True), + ], +) +def test_kubernetes_parse_keyvalue_list(items, requires_both): + with pytest.raises(KubernetesException): + KubernetesDecorator.parse_kube_keyvalue_list(items, requires_both) From 412e2532b6dd7fe817b497b7db8c418c7752a042 Mon Sep 17 00:00:00 2001 From: David Pollack Date: Tue, 18 Apr 2023 09:52:00 +0200 Subject: [PATCH 2/6] Fix null labels value in argo --- metaflow/plugins/argo/argo_workflows.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index 0c0f6d4c084..e9c0337b822 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -1064,7 +1064,7 @@ def label(self, key, value): def labels(self, labels): if "labels" not in self.payload: self.payload["labels"] = {} - self.payload["labels"].update(labels) + self.payload["labels"].update(labels or {}) return self def name(self, name): @@ -1173,7 +1173,7 @@ def label(self, key, value): def labels(self, labels): if "labels" not in self.payload: self.payload["labels"] = {} - self.payload["labels"].update(labels) + self.payload["labels"].update(labels or {}) return self def labels_from(self, labels_from): From af39a0e660fb609f4b1ae77fe7080a7c0d869f78 Mon Sep 17 00:00:00 2001 From: David Pollack Date: Thu, 27 Apr 2023 14:50:44 +0200 Subject: [PATCH 3/6] Only allow env vars to add labels --- metaflow/plugins/argo/argo_workflows.py | 11 ++ metaflow/plugins/kubernetes/kubernetes.py | 64 ++++++++++- metaflow/plugins/kubernetes/kubernetes_cli.py | 19 +--- .../kubernetes/kubernetes_decorator.py | 100 +----------------- ...rnetes_decorator.py => test_kubernetes.py} | 15 +-- 5 files changed, 90 insertions(+), 119 deletions(-) rename test/unit/{test_kubernetes_decorator.py => test_kubernetes.py} (83%) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index 3758e182dab..10e655c197b 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -30,6 +30,7 @@ DEFAULT_METADATA, DEFAULT_SECRETS_BACKEND_TYPE, KUBERNETES_FETCH_EC2_METADATA, + KUBERNETES_LABELS, KUBERNETES_NAMESPACE, KUBERNETES_NODE_SELECTOR, KUBERNETES_SANDBOX_INIT_SCRIPT, @@ -40,6 +41,10 @@ ) from metaflow.mflog import BASH_SAVE_LOGS, bash_capture_logs, export_mflog_env_vars from metaflow.parameters import deploy_time_eval +from metaflow.plugins.kubernetes.kubernetes import ( + parse_kube_keyvalue_list, + validate_kube_labels, +) from metaflow.util import ( compress_list, dict_to_cli_options, @@ -469,6 +474,11 @@ def _compile_workflow_template(self): } ) + # get labels from env vars + env_labels = KUBERNETES_LABELS.split(",") + env_labels = parse_kube_keyvalue_list(env_labels, False) + validate_kube_labels(env_labels) + return ( WorkflowTemplate() .metadata( @@ -550,6 +560,7 @@ def _compile_workflow_template(self): .label("app.kubernetes.io/name", "metaflow-task") .label("app.kubernetes.io/part-of", "metaflow") .annotations(annotations) + .labels(env_labels) ) # Set the entrypoint to flow name .entrypoint(self.flow.name) diff --git a/metaflow/plugins/kubernetes/kubernetes.py b/metaflow/plugins/kubernetes/kubernetes.py index eacefebc03a..77ea167939c 100644 --- a/metaflow/plugins/kubernetes/kubernetes.py +++ b/metaflow/plugins/kubernetes/kubernetes.py @@ -1,8 +1,10 @@ import json import math import os +import re import shlex import time +from typing import Dict, List, Optional from metaflow import current, util from metaflow.exception import MetaflowException @@ -20,8 +22,9 @@ DEFAULT_AWS_CLIENT_PROVIDER, DEFAULT_METADATA, DEFAULT_SECRETS_BACKEND_TYPE, - KUBERNETES_SANDBOX_INIT_SCRIPT, KUBERNETES_FETCH_EC2_METADATA, + KUBERNETES_LABELS, + KUBERNETES_SANDBOX_INIT_SCRIPT, S3_ENDPOINT_URL, SERVICE_HEADERS, SERVICE_INTERNAL_URL, @@ -190,7 +193,7 @@ def create_job( retries=0, step_name=step_name, tolerations=tolerations, - labels=labels, + labels=self._get_labels(labels), use_tmpfs=use_tmpfs, tmpfs_tempdir=tmpfs_tempdir, tmpfs_size=tmpfs_size, @@ -376,3 +379,60 @@ def wait_for_launch(job): "stderr", job_id=self._job.id, ) + + @staticmethod + def _get_labels(extra_labels=None): + if extra_labels is None: + extra_labels = {} + env_labels = KUBERNETES_LABELS.split(",") + env_labels = parse_kube_keyvalue_list(env_labels, False) + labels = {**env_labels, **extra_labels} + validate_kube_labels(labels) + return labels + + +def validate_kube_labels( + labels: Optional[Dict[str, Optional[str]]], +) -> bool: + """Validate label values. + + This validates the kubernetes label values. It does not validate the keys. + Ideally, keys should be static and also the validation rules for keys are + more complex than those for values. For full validation rules, see: + + https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set + """ + + def validate_label(s: Optional[str]): + regex_match = r"^(([A-Za-z0-9][-A-Za-z0-9_.]{0,61})?[A-Za-z0-9])?$" + if not s: + # allow empty label + return True + if not re.search(regex_match, s): + raise KubernetesException( + 'Invalid value: "%s"\n' + "A valid label must be an empty string or one that\n" + " - Consist of alphanumeric, '-', '_' or '.' characters\n" + " - Begins and ends with an alphanumeric character\n" + " - Is at most 63 characters" % s + ) + return True + + return all([validate_label(v) for v in labels.values()]) if labels else True + + +def parse_kube_keyvalue_list(items: List[str], requires_both: bool = True): + try: + ret = {} + for item_str in items: + item = item_str.split("=", 1) + if requires_both: + item[1] # raise IndexError + if str(item[0]) in ret: + raise KubernetesException("Duplicate key found: %s" % str(item[0])) + ret[str(item[0])] = str(item[1]) if len(item) > 1 else None + return ret + except KubernetesException as e: + raise e + except (AttributeError, IndexError): + raise KubernetesException("Unable to parse kubernetes list: %s" % items) diff --git a/metaflow/plugins/kubernetes/kubernetes_cli.py b/metaflow/plugins/kubernetes/kubernetes_cli.py index 7d0655831af..ec9d4f5f36d 100644 --- a/metaflow/plugins/kubernetes/kubernetes_cli.py +++ b/metaflow/plugins/kubernetes/kubernetes_cli.py @@ -7,10 +7,10 @@ from metaflow._vendor import click from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY, CommandException from metaflow.metadata.util import sync_local_metadata_from_datastore -from metaflow.metaflow_config import DATASTORE_LOCAL_DIR +from metaflow.metaflow_config import DATASTORE_LOCAL_DIR, KUBERNETES_LABELS from metaflow.mflog import TASK_LOG_SOURCE -from .kubernetes import Kubernetes, KubernetesKilledException +from .kubernetes import Kubernetes, KubernetesKilledException, parse_kube_keyvalue_list from .kubernetes_decorator import KubernetesDecorator @@ -97,12 +97,6 @@ def kubernetes(): type=JSONTypeClass(), multiple=False, ) -@click.option( - "--labels", - multiple=True, - default=None, - help="Labels for Kubernetes pod.", -) @click.pass_context def step( ctx, @@ -126,7 +120,6 @@ def step( tmpfs_path=None, run_time_limit=None, tolerations=None, - labels=None, **kwargs ): def echo(msg, stream="stderr", job_id=None, **kwargs): @@ -192,12 +185,7 @@ def echo(msg, stream="stderr", job_id=None, **kwargs): stderr_location = ds.get_log_location(TASK_LOG_SOURCE, "stderr") # `node_selector` is a tuple of strings, convert it to a dictionary - node_selector = KubernetesDecorator.parse_kube_keyvalue_list(node_selector) - - # `labels` is a tuple of strings or a tuple with a single comma separated string - # convert it to a dict - labels = KubernetesDecorator.parse_kube_keyvalue_list(labels, False) - KubernetesDecorator.validate_kube_labels(labels) + node_selector = parse_kube_keyvalue_list(node_selector) def _sync_metadata(): if ctx.obj.metadata.TYPE == "local": @@ -244,7 +232,6 @@ def _sync_metadata(): run_time_limit=run_time_limit, env=env, tolerations=tolerations, - labels=labels, ) except Exception as e: traceback.print_exc(chain=False) diff --git a/metaflow/plugins/kubernetes/kubernetes_decorator.py b/metaflow/plugins/kubernetes/kubernetes_decorator.py index 06d03258402..0aa119415b5 100644 --- a/metaflow/plugins/kubernetes/kubernetes_decorator.py +++ b/metaflow/plugins/kubernetes/kubernetes_decorator.py @@ -15,22 +15,21 @@ DATASTORE_LOCAL_DIR, KUBERNETES_CONTAINER_IMAGE, KUBERNETES_CONTAINER_REGISTRY, + KUBERNETES_FETCH_EC2_METADATA, KUBERNETES_GPU_VENDOR, KUBERNETES_LABELS, KUBERNETES_NAMESPACE, KUBERNETES_NODE_SELECTOR, - KUBERNETES_TOLERATIONS, KUBERNETES_SECRETS, KUBERNETES_SERVICE_ACCOUNT, - KUBERNETES_FETCH_EC2_METADATA, + KUBERNETES_TOLERATIONS, ) from metaflow.plugins.resources_decorator import ResourcesDecorator from metaflow.plugins.timeout_decorator import get_run_time_limit_for_task from metaflow.sidecar import Sidecar from ..aws.aws_utils import get_docker_registry, get_ec2_instance_metadata - -from .kubernetes import KubernetesException +from .kubernetes import KubernetesException, parse_kube_keyvalue_list try: unicode @@ -70,8 +69,6 @@ class KubernetesDecorator(StepDecorator): in Metaflow configuration. tolerations : List[str], default: METAFLOW_KUBERNETES_TOLERATIONS Kubernetes tolerations to use when launching pod in Kubernetes. - labels : Dict[str, str], default: METAFLOW_KUBERNETES_LABELS - Kubernetes labels to use when launching pod in Kubernetes. use_tmpfs: bool, default: False This enables an explicit tmpfs mount for this step. tmpfs_tempdir: bool, default: True @@ -93,7 +90,6 @@ class KubernetesDecorator(StepDecorator): "service_account": None, "secrets": None, # e.g., mysecret "node_selector": None, # e.g., kubernetes.io/os=linux - "labels": None, # e.g., my_label=my_value "namespace": None, "gpu": None, # value of 0 implies that the scheduled node should not have GPUs "gpu_vendor": None, @@ -121,17 +117,9 @@ def __init__(self, attributes=None, statically_defined=False): self.attributes["node_selector"] = KUBERNETES_NODE_SELECTOR if not self.attributes["tolerations"] and KUBERNETES_TOLERATIONS: self.attributes["tolerations"] = json.loads(KUBERNETES_TOLERATIONS) - if not self.attributes["labels"] and KUBERNETES_LABELS: - self.attributes["labels"] = KUBERNETES_LABELS - - if isinstance(self.attributes["labels"], str): - self.attributes["labels"] = self.parse_kube_keyvalue_list( - self.attributes["labels"].split(","), False - ) - self.validate_kube_labels(self.attributes["labels"]) if isinstance(self.attributes["node_selector"], str): - self.attributes["node_selector"] = self.parse_kube_keyvalue_list( + self.attributes["node_selector"] = parse_kube_keyvalue_list( self.attributes["node_selector"].split(",") ) @@ -328,7 +316,7 @@ def runtime_step_cli( for k, v in self.attributes.items(): if k == "namespace": cli_args.command_options["k8s_namespace"] = v - elif k in {"node_selector", "labels"} and v: + elif k in {"node_selector"} and v: cli_args.command_options[k] = [ "=".join([key, str(val)]) if val else key for key, val in v.items() @@ -444,81 +432,3 @@ def _save_package_once(cls, flow_datastore, package): cls.package_url, cls.package_sha = flow_datastore.save_data( [package.blob], len_hint=1 )[0] - - @classmethod - def _parse_decorator_spec(cls, deco_spec: str): - if not deco_spec: - return cls() - - valid_options = "|".join(cls.defaults.keys()) - deco_spec_parts = [] - for part in re.split(""",(?=[\s\w]+[{}]=)""".format(valid_options), deco_spec): - name, val = part.split("=", 1) - if name in {"labels", "node_selector"}: - try: - tmp_vals = json.loads(val.strip().replace('\\"', '"')) - for val_i in tmp_vals.values(): - if not (val_i is None or isinstance(val_i, str)): - raise KubernetesException( - "All values must be string or null." - ) - except json.JSONDecodeError: - if val.startswith("{"): - raise KubernetesException( - "Malform json detected in %s" % str(val) - ) - both = name == "node_selector" - val = json.dumps( - cls.parse_kube_keyvalue_list(val.split(","), both), - separators=(",", ":"), - ) - deco_spec_parts.append("=".join([name, val])) - deco_spec_parsed = ",".join(deco_spec_parts) - return super()._parse_decorator_spec(deco_spec_parsed) - - @staticmethod - def parse_kube_keyvalue_list(items: List[str], requires_both: bool = True): - try: - ret = {} - for item_str in items: - item = item_str.split("=", 1) - if requires_both: - item[1] # raise IndexError - if str(item[0]) in ret: - raise KubernetesException("Duplicate key found: %s" % str(item[0])) - ret[str(item[0])] = str(item[1]) if len(item) > 1 else None - return ret - except KubernetesException as e: - raise e - except (AttributeError, IndexError): - raise KubernetesException("Unable to parse kubernetes list: %s" % items) - - @staticmethod - def validate_kube_labels( - labels: Optional[Dict[str, Optional[str]]], - ) -> bool: - """Validate label values. - - This validates the kubernetes label values. It does not validate the keys. - Ideally, keys should be static and also the validation rules for keys are - more complex than those for values. For full validation rules, see: - - https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set - """ - - def validate_label(s: Optional[str]): - regex_match = r"^(([A-Za-z0-9][-A-Za-z0-9_.]{0,61})?[A-Za-z0-9])?$" - if not s: - # allow empty label - return True - if not re.search(regex_match, s): - raise KubernetesException( - 'Invalid value: "%s"\n' - "A valid label must be an empty string or one that\n" - " - Consist of alphanumeric, '-', '_' or '.' characters\n" - " - Begins and ends with an alphanumeric character\n" - " - Is at most 63 characters" % s - ) - return True - - return all([validate_label(v) for v in labels.values()]) if labels else True diff --git a/test/unit/test_kubernetes_decorator.py b/test/unit/test_kubernetes.py similarity index 83% rename from test/unit/test_kubernetes_decorator.py rename to test/unit/test_kubernetes.py index f779faceb9f..2169bfd4e0b 100644 --- a/test/unit/test_kubernetes_decorator.py +++ b/test/unit/test_kubernetes.py @@ -1,7 +1,10 @@ import pytest -from metaflow.plugins.kubernetes.kubernetes import KubernetesException -from metaflow.plugins.kubernetes.kubernetes_decorator import KubernetesDecorator +from metaflow.plugins.kubernetes.kubernetes import ( + KubernetesException, + validate_kube_labels, + parse_kube_keyvalue_list, +) @pytest.mark.parametrize( @@ -38,7 +41,7 @@ ], ) def test_kubernetes_decorator_validate_kube_labels(labels): - assert KubernetesDecorator.validate_kube_labels(labels) + assert validate_kube_labels(labels) @pytest.mark.parametrize( @@ -65,7 +68,7 @@ def test_kubernetes_decorator_validate_kube_labels(labels): def test_kubernetes_decorator_validate_kube_labels_fail(labels): """Fail if label contains invalid characters or is too long""" with pytest.raises(KubernetesException): - KubernetesDecorator.validate_kube_labels(labels) + validate_kube_labels(labels) @pytest.mark.parametrize( @@ -78,7 +81,7 @@ def test_kubernetes_decorator_validate_kube_labels_fail(labels): ], ) def test_kubernetes_parse_keyvalue_list(items, requires_both, expected): - ret = KubernetesDecorator.parse_kube_keyvalue_list(items, requires_both) + ret = parse_kube_keyvalue_list(items, requires_both) assert ret == expected @@ -91,4 +94,4 @@ def test_kubernetes_parse_keyvalue_list(items, requires_both, expected): ) def test_kubernetes_parse_keyvalue_list(items, requires_both): with pytest.raises(KubernetesException): - KubernetesDecorator.parse_kube_keyvalue_list(items, requires_both) + parse_kube_keyvalue_list(items, requires_both) From 542dda3746c17e1ef8844beaf742dac07a7bb5a4 Mon Sep 17 00:00:00 2001 From: David Pollack Date: Thu, 27 Apr 2023 15:02:29 +0200 Subject: [PATCH 4/6] Fix empty label case --- metaflow/plugins/argo/argo_workflows.py | 2 +- metaflow/plugins/kubernetes/kubernetes.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index 10e655c197b..88f9d39ab38 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -475,7 +475,7 @@ def _compile_workflow_template(self): ) # get labels from env vars - env_labels = KUBERNETES_LABELS.split(",") + env_labels = KUBERNETES_LABELS.split(",") if KUBERNETES_LABELS else {} env_labels = parse_kube_keyvalue_list(env_labels, False) validate_kube_labels(env_labels) diff --git a/metaflow/plugins/kubernetes/kubernetes.py b/metaflow/plugins/kubernetes/kubernetes.py index 77ea167939c..78d38ab6471 100644 --- a/metaflow/plugins/kubernetes/kubernetes.py +++ b/metaflow/plugins/kubernetes/kubernetes.py @@ -384,7 +384,7 @@ def wait_for_launch(job): def _get_labels(extra_labels=None): if extra_labels is None: extra_labels = {} - env_labels = KUBERNETES_LABELS.split(",") + env_labels = KUBERNETES_LABELS.split(",") if KUBERNETES_LABELS else {} env_labels = parse_kube_keyvalue_list(env_labels, False) labels = {**env_labels, **extra_labels} validate_kube_labels(labels) From ca685ed83905df22209e200a0bfafad7bf3f5e6d Mon Sep 17 00:00:00 2001 From: David Pollack Date: Thu, 27 Apr 2023 15:43:40 +0200 Subject: [PATCH 5/6] Adjustments from PR comments --- metaflow/plugins/argo/argo_workflows.py | 4 +--- metaflow/plugins/kubernetes/kubernetes.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index 88f9d39ab38..a39ee18638c 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -475,7 +475,7 @@ def _compile_workflow_template(self): ) # get labels from env vars - env_labels = KUBERNETES_LABELS.split(",") if KUBERNETES_LABELS else {} + env_labels = KUBERNETES_LABELS.split(",") if KUBERNETES_LABELS else [] env_labels = parse_kube_keyvalue_list(env_labels, False) validate_kube_labels(env_labels) @@ -1121,8 +1121,6 @@ def _container_templates(self): # the field 'task-id' in 'parameters' # .annotation("metaflow/task_id", ...) .annotation("metaflow/attempt", retry_count) - # Set labels - .labels(resources.get("labels")) ) # Set emptyDir volume for state management .empty_dir_volume("out") diff --git a/metaflow/plugins/kubernetes/kubernetes.py b/metaflow/plugins/kubernetes/kubernetes.py index 78d38ab6471..0f83cd7cd46 100644 --- a/metaflow/plugins/kubernetes/kubernetes.py +++ b/metaflow/plugins/kubernetes/kubernetes.py @@ -384,7 +384,7 @@ def wait_for_launch(job): def _get_labels(extra_labels=None): if extra_labels is None: extra_labels = {} - env_labels = KUBERNETES_LABELS.split(",") if KUBERNETES_LABELS else {} + env_labels = KUBERNETES_LABELS.split(",") if KUBERNETES_LABELS else [] env_labels = parse_kube_keyvalue_list(env_labels, False) labels = {**env_labels, **extra_labels} validate_kube_labels(labels) From bf70791be71b1273abe47daad5996e0095032069 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Thu, 27 Apr 2023 22:07:50 +0300 Subject: [PATCH 6/6] refactor argo kubernetes label getter. add labels to argo-workflow sensors --- metaflow/plugins/argo/argo_workflows.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index a39ee18638c..57927bc8201 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -146,6 +146,7 @@ def __init__( self.triggers, self.trigger_options = self._process_triggers() self._schedule, self._timezone = self._get_schedule() + self.kubernetes_labels = self._get_kubernetes_labels() self._workflow_template = self._compile_workflow_template() self._sensor = self._compile_sensor() @@ -201,6 +202,19 @@ def trigger(cls, name, parameters=None): except Exception as e: raise ArgoWorkflowsException(str(e)) + @staticmethod + def _get_kubernetes_labels(): + """ + Get Kubernetes labels from environment variable. + Parses the string into a dict and validates that values adhere to Kubernetes restrictions. + """ + if not KUBERNETES_LABELS: + return {} + env_labels = KUBERNETES_LABELS.split(",") + env_labels = parse_kube_keyvalue_list(env_labels, False) + validate_kube_labels(env_labels) + return env_labels + def _get_schedule(self): schedule = self.flow._flow_decorators.get("schedule") if schedule: @@ -474,11 +488,6 @@ def _compile_workflow_template(self): } ) - # get labels from env vars - env_labels = KUBERNETES_LABELS.split(",") if KUBERNETES_LABELS else [] - env_labels = parse_kube_keyvalue_list(env_labels, False) - validate_kube_labels(env_labels) - return ( WorkflowTemplate() .metadata( @@ -560,7 +569,7 @@ def _compile_workflow_template(self): .label("app.kubernetes.io/name", "metaflow-task") .label("app.kubernetes.io/part-of", "metaflow") .annotations(annotations) - .labels(env_labels) + .labels(self.kubernetes_labels) ) # Set the entrypoint to flow name .entrypoint(self.flow.name) @@ -1350,6 +1359,7 @@ def _compile_sensor(self): .namespace(KUBERNETES_NAMESPACE) .label("app.kubernetes.io/name", "metaflow-sensor") .label("app.kubernetes.io/part-of", "metaflow") + .labels(self.kubernetes_labels) .annotations(annotations) ) .spec(