diff --git a/metaflow/metaflow_config.py b/metaflow/metaflow_config.py
index 0f0b5198a8f..60cbdd6aef6 100644
--- a/metaflow/metaflow_config.py
+++ b/metaflow/metaflow_config.py
@@ -283,6 +283,8 @@
 KUBERNETES_SECRETS = from_conf("KUBERNETES_SECRETS", "")
 # Default labels for kubernetes pods
 KUBERNETES_LABELS = from_conf("KUBERNETES_LABELS", "")
+# Default annotations for kubernetes pods
+KUBERNETES_ANNOTATIONS = from_conf("KUBERNETES_ANNOTATIONS", "")
 # Default GPU vendor to use by K8S jobs created by Metaflow (supports nvidia, amd)
 KUBERNETES_GPU_VENDOR = from_conf("KUBERNETES_GPU_VENDOR", "nvidia")
 # Default container image for K8S
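The new variable follows the same convention as METAFLOW_KUBERNETES_LABELS: a comma-separated list of key=value pairs that is later split on "," and parsed. A minimal sketch of how a deployment might set both (the values here are purely illustrative):

    import os

    # Hypothetical example values; the expected format is comma-separated
    # key=value pairs, parsed downstream by parse_kube_keyvalue_list.
    os.environ["METAFLOW_KUBERNETES_LABELS"] = "team=ml,env=prod"
    os.environ["METAFLOW_KUBERNETES_ANNOTATIONS"] = "owner=data-platform,cost-center=c123"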
diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py
index 2873dfb7909..3c516b7bcf9 100644
--- a/metaflow/plugins/argo/argo_workflows.py
+++ b/metaflow/plugins/argo/argo_workflows.py
@@ -30,7 +30,6 @@
     DEFAULT_METADATA,
     DEFAULT_SECRETS_BACKEND_TYPE,
     KUBERNETES_FETCH_EC2_METADATA,
-    KUBERNETES_LABELS,
     KUBERNETES_NAMESPACE,
     KUBERNETES_NODE_SELECTOR,
     KUBERNETES_SANDBOX_INIT_SCRIPT,
@@ -44,10 +43,6 @@
 )
 from metaflow.mflog import BASH_SAVE_LOGS, bash_capture_logs, export_mflog_env_vars
 from metaflow.parameters import deploy_time_eval
-from metaflow.plugins.kubernetes.kubernetes import (
-    parse_kube_keyvalue_list,
-    validate_kube_labels,
-)
 from metaflow.util import (
     compress_list,
     dict_to_cli_options,
@@ -157,6 +152,7 @@ def __init__(
         self._schedule, self._timezone = self._get_schedule()

         self.kubernetes_labels = self._get_kubernetes_labels()
+        self.kubernetes_annotations = self._get_kubernetes_annotations()
         self._workflow_template = self._compile_workflow_template()
         self._sensor = self._compile_sensor()

@@ -266,18 +262,49 @@ def trigger(cls, name, parameters=None):
         except Exception as e:
             raise ArgoWorkflowsException(str(e))

-    @staticmethod
-    def _get_kubernetes_labels():
+    def _get_kubernetes_labels(self):
         """
-        Get Kubernetes labels from environment variable.
-        Parses the string into a dict and validates that values adhere to Kubernetes restrictions.
+        Get Kubernetes labels from the start step decorator.
         """
-        if not KUBERNETES_LABELS:
-            return {}
-        env_labels = KUBERNETES_LABELS.split(",")
-        env_labels = parse_kube_keyvalue_list(env_labels, False)
-        validate_kube_labels(env_labels)
-        return env_labels
+
+        resources = dict(
+            [
+                deco
+                for node in self.graph
+                if node.name == "start"
+                for deco in node.decorators
+                if deco.name == "kubernetes"
+            ][0].attributes
+        )
+        return resources["labels"] or {}
+
+    def _get_kubernetes_annotations(self):
+        """
+        Get Kubernetes annotations from the start step decorator and append Argo-specific annotations.
+        """
+        resources = dict(
+            [
+                deco
+                for node in self.graph
+                if node.name == "start"
+                for deco in node.decorators
+                if deco.name == "kubernetes"
+            ][0].attributes
+        )
+        annotations = {}
+        if resources["annotations"] is not None:
+            # Make a copy so we do not mess up possible start-step specific annotations.
+            annotations = resources["annotations"].copy()
+
+        annotations.update(
+            {
+                "metaflow/production_token": self.production_token,
+                "metaflow/owner": self.username,
+                "metaflow/user": "argo-workflows",
+                "metaflow/flow_name": self.flow.name,
+            }
+        )
+        return annotations

     def _get_schedule(self):
         schedule = self.flow._flow_decorators.get("schedule")
@@ -568,21 +595,6 @@ def _compile_workflow_template(self):
         # generate container templates at the top level (in WorkflowSpec) and maintain
         # references to them within the DAGTask.

-        annotations = {
-            "metaflow/production_token": self.production_token,
-            "metaflow/owner": self.username,
-            "metaflow/user": "argo-workflows",
-            "metaflow/flow_name": self.flow.name,
-        }
-        if current.get("project_name"):
-            annotations.update(
-                {
-                    "metaflow/project_name": current.project_name,
-                    "metaflow/branch_name": current.branch_name,
-                    "metaflow/project_flow_name": current.project_flow_name,
-                }
-            )
-
         return (
             WorkflowTemplate()
             .metadata(
@@ -595,7 +607,7 @@
                 .namespace(KUBERNETES_NAMESPACE)
                 .label("app.kubernetes.io/name", "metaflow-flow")
                 .label("app.kubernetes.io/part-of", "metaflow")
-                .annotations(annotations)
+                .annotations(self.kubernetes_annotations)
             )
             .spec(
                 WorkflowSpec()
@@ -628,7 +640,10 @@
                 .label("app.kubernetes.io/name", "metaflow-run")
                 .label("app.kubernetes.io/part-of", "metaflow")
                 .annotations(
-                    {**annotations, **{"metaflow/run_id": "argo-{{workflow.name}}"}}
+                    {
+                        **self.kubernetes_annotations,
+                        **{"metaflow/run_id": "argo-{{workflow.name}}"},
+                    }
                 )
                 # TODO: Set dynamic labels using labels_from. Ideally, we would
                 # want to expose run_id as a label. It's easy to add labels,
@@ -661,10 +676,10 @@
             # Set common pod metadata.
             .pod_metadata(
                 Metadata()
+                .labels(self.kubernetes_labels)
                 .label("app.kubernetes.io/name", "metaflow-task")
                 .label("app.kubernetes.io/part-of", "metaflow")
-                .annotations(annotations)
-                .labels(self.kubernetes_labels)
+                .annotations(self.kubernetes_annotations)
             )
             # Set the entrypoint to flow name
             .entrypoint(self.flow.name)
@@ -1290,13 +1305,15 @@ def _container_templates(self):
                 minutes_between_retries=minutes_between_retries,
             )
             .metadata(
-                ObjectMeta().annotation("metaflow/step_name", node.name)
+                ObjectMeta()
+                .annotation("metaflow/step_name", node.name)
                 # Unfortunately, we can't set the task_id since it is generated
                 # inside the pod. However, it can be inferred from the annotation
                 # set by argo-workflows - `workflows.argoproj.io/outputs` - refer
                 # the field 'task-id' in 'parameters'
                 # .annotation("metaflow/task_id", ...)
                 .annotation("metaflow/attempt", retry_count)
+                .labels(resources["labels"])
             )
             # Set emptyDir volume for state management
             .empty_dir_volume("out")
@@ -1659,23 +1676,6 @@ def _compile_sensor(self):
                 "sdk (https://pypi.org/project/kubernetes/) first."
             )

-        labels = {"app.kubernetes.io/part-of": "metaflow"}
-
-        annotations = {
-            "metaflow/production_token": self.production_token,
-            "metaflow/owner": self.username,
-            "metaflow/user": "argo-workflows",
-            "metaflow/flow_name": self.flow.name,
-        }
-        if current.get("project_name"):
-            annotations.update(
-                {
-                    "metaflow/project_name": current.project_name,
-                    "metaflow/branch_name": current.branch_name,
-                    "metaflow/project_flow_name": current.project_flow_name,
-                }
-            )
-
         return (
             Sensor()
             .metadata(
@@ -1683,10 +1683,10 @@
                 ObjectMeta()
                 .name(self.name.replace(".", "-"))
                 .namespace(KUBERNETES_NAMESPACE)
+                .labels(self.kubernetes_labels)
                 .label("app.kubernetes.io/name", "metaflow-sensor")
                 .label("app.kubernetes.io/part-of", "metaflow")
-                .labels(self.kubernetes_labels)
-                .annotations(annotations)
+                .annotations(self.kubernetes_annotations)
             )
             .spec(
                 SensorSpec().template(
@@ -1696,7 +1696,7 @@
                     ObjectMeta()
                     .label("app.kubernetes.io/name", "metaflow-sensor")
                     .label("app.kubernetes.io/part-of", "metaflow")
-                    .annotations(annotations)
+                    .annotations(self.kubernetes_annotations)
                 )
                 .container(
                     # Run sensor in guaranteed QoS. The sensor isn't doing a lot
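With this change the Argo compiler no longer reads METAFLOW_KUBERNETES_LABELS itself; both labels and annotations are taken from the @kubernetes decorator on the start step and propagated to the template, workflow, and pod metadata. A rough sketch of a flow that exercises this path (the flow name and values are hypothetical):

    from metaflow import FlowSpec, kubernetes, step


    class HelloFlow(FlowSpec):
        # Labels/annotations set here land in the WorkflowTemplate metadata,
        # the workflow metadata, and the common pod metadata on deploy.
        @kubernetes(labels={"team": "ml"}, annotations={"owner": "data-platform"})
        @step
        def start(self):
            self.next(self.end)

        @step
        def end(self):
            pass


    if __name__ == "__main__":
        HelloFlow()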
characters\n" - " - Begins and ends with an alphanumeric character\n" - " - Is at most 63 characters" % s - ) - return True - - return all([validate_label(v) for v in labels.values()]) if labels else True - def parse_kube_keyvalue_list(items: List[str], requires_both: bool = True): try: diff --git a/metaflow/plugins/kubernetes/kubernetes_cli.py b/metaflow/plugins/kubernetes/kubernetes_cli.py index 23bc2d3e601..d1817d21515 100644 --- a/metaflow/plugins/kubernetes/kubernetes_cli.py +++ b/metaflow/plugins/kubernetes/kubernetes_cli.py @@ -7,7 +7,7 @@ from metaflow._vendor import click from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY, CommandException from metaflow.metadata.util import sync_local_metadata_from_datastore -from metaflow.metaflow_config import DATASTORE_LOCAL_DIR, KUBERNETES_LABELS +from metaflow.metaflow_config import DATASTORE_LOCAL_DIR from metaflow.mflog import TASK_LOG_SOURCE from .kubernetes import Kubernetes, KubernetesKilledException, parse_kube_keyvalue_list @@ -105,6 +105,18 @@ def kubernetes(): type=JSONTypeClass(), multiple=False, ) +@click.option( + "--labels", + default=None, + type=JSONTypeClass(), + multiple=False, +) +@click.option( + "--annotations", + default=None, + type=JSONTypeClass(), + multiple=False, +) @click.pass_context def step( ctx, @@ -130,6 +142,8 @@ def step( run_time_limit=None, persistent_volume_claims=None, tolerations=None, + labels=None, + annotations=None, **kwargs ): def echo(msg, stream="stderr", job_id=None, **kwargs): @@ -244,6 +258,8 @@ def _sync_metadata(): env=env, persistent_volume_claims=persistent_volume_claims, tolerations=tolerations, + labels=labels, + annotations=annotations, ) except Exception as e: traceback.print_exc(chain=False) diff --git a/metaflow/plugins/kubernetes/kubernetes_decorator.py b/metaflow/plugins/kubernetes/kubernetes_decorator.py index 4dcd7fd640a..9c7ad3e1a48 100644 --- a/metaflow/plugins/kubernetes/kubernetes_decorator.py +++ b/metaflow/plugins/kubernetes/kubernetes_decorator.py @@ -2,6 +2,8 @@ import os import platform import sys +import re +from typing import Optional, Dict from metaflow import current from metaflow.decorators import StepDecorator @@ -15,6 +17,8 @@ KUBERNETES_FETCH_EC2_METADATA, KUBERNETES_IMAGE_PULL_POLICY, KUBERNETES_GPU_VENDOR, + KUBERNETES_LABELS, + KUBERNETES_ANNOTATIONS, KUBERNETES_NAMESPACE, KUBERNETES_NODE_SELECTOR, KUBERNETES_PERSISTENT_VOLUME_CLAIMS, @@ -66,6 +70,10 @@ class KubernetesDecorator(StepDecorator): in Metaflow configuration. tolerations : List[str], default: METAFLOW_KUBERNETES_TOLERATIONS Kubernetes tolerations to use when launching pod in Kubernetes. + labels: Dict[str, str], default: METAFLOW_KUBERNETES_LABELS + Kubernetes labels to use when launching pod in Kubernetes. + annotations: Dict[str, str], default: METAFLOW_KUBERNETES_ANNOTATIONS + Kubernetes annotations to use when launching pod in Kubernetes. use_tmpfs: bool, default: False This enables an explicit tmpfs mount for this step. tmpfs_tempdir: bool, default: True @@ -96,6 +104,8 @@ class KubernetesDecorator(StepDecorator): "gpu_vendor": None, "tolerations": None, # e.g., [{"key": "arch", "operator": "Equal", "value": "amd"}, # {"key": "foo", "operator": "Equal", "value": "bar"}] + "labels": None, # e.g. {"test-label": "value", "another-label":"value2"} + "annotations": None, # e.g. 
{"note": "value", "another-note": "value2"} "use_tmpfs": None, "tmpfs_tempdir": True, "tmpfs_size": None, @@ -158,6 +168,60 @@ def __init__(self, attributes=None, statically_defined=False): except (NameError, ImportError): pass + # Label source precedence (decreasing): + # - System labels + # - Decorator labels: @kubernetes(labels={}) + # - Environment variable labels: METAFLOW_KUBERNETES_LABELS= + deco_labels = {} + if self.attributes["labels"] is not None: + deco_labels = self.attributes["labels"] + + env_labels = {} + if KUBERNETES_LABELS: + env_labels = parse_kube_keyvalue_list(KUBERNETES_LABELS.split(","), False) + + system_labels = { + "app.kubernetes.io/name": "metaflow-task", + "app.kubernetes.io/part-of": "metaflow", + } + + self.attributes["labels"] = {**env_labels, **deco_labels, **system_labels} + + # Annotations + # annotation precedence (decreasing): + # - System annotations + # - Decorator annotations + # - Environment annotations + deco_annotations = {} + if self.attributes["annotations"] is not None: + deco_annotations = self.attributes["annotations"] + + env_annotations = {} + if KUBERNETES_ANNOTATIONS: + env_annotations = parse_kube_keyvalue_list( + KUBERNETES_ANNOTATIONS.split(","), False + ) + + system_annotations = { + "metaflow/user": current.username, + "metaflow/flow_name": current.flow_name, + } + + if current.get("project_name"): + system_annotations.update( + { + "metaflow/project_name": current.project_name, + "metaflow/branch_name": current.branch_name, + "metaflow/project_flow_name": current.project_flow_name, + } + ) + + self.attributes["annotations"] = { + **env_annotations, + **deco_annotations, + **system_annotations, + } + # If no docker image is explicitly specified, impute a default image. if not self.attributes["image"]: # If metaflow-config specifies a docker image, just use that. @@ -281,6 +345,9 @@ def step_init(self, flow, graph, step, decos, environment, flow_datastore, logge ) ) + validate_kube_labels_or_annotations(self.attributes["labels"]) + validate_kube_labels_or_annotations(self.attributes["annotations"]) + def package_init(self, flow, step_name, environment): try: # Kubernetes is a soft dependency. @@ -332,7 +399,12 @@ def runtime_step_cli( "=".join([key, str(val)]) if val else key for key, val in v.items() ] - elif k in ["tolerations", "persistent_volume_claims"]: + elif k in [ + "tolerations", + "persistent_volume_claims", + "labels", + "annotations", + ]: cli_args.command_options[k] = json.dumps(v) else: cli_args.command_options[k] = v @@ -443,3 +515,33 @@ def _save_package_once(cls, flow_datastore, package): cls.package_url, cls.package_sha = flow_datastore.save_data( [package.blob], len_hint=1 )[0] + + +def validate_kube_labels_or_annotations( + labels: Optional[Dict[str, Optional[str]]], +) -> bool: + """Validate label values. + + This validates the kubernetes label values. It does not validate the keys. + Ideally, keys should be static and also the validation rules for keys are + more complex than those for values. For full validation rules, see: + + https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set + """ + + def validate_label(s: Optional[str]): + regex_match = r"^(([A-Za-z0-9][-A-Za-z0-9_.]{0,61})?[A-Za-z0-9])?$" + if not s: + # allow empty label + return True + if not re.search(regex_match, s): + raise KubernetesException( + 'Invalid value: "%s"\n' + "A valid label must be an empty string or one that\n" + " - Consist of alphanumeric, '-', '_' or '.' 
characters\n" + " - Begins and ends with an alphanumeric character\n" + " - Is at most 63 characters" % s + ) + return True + + return all([validate_label(v) for v in labels.values()]) if labels else True diff --git a/test/unit/test_kubernetes.py b/test/unit/test_kubernetes.py index 2169bfd4e0b..607394b2b3d 100644 --- a/test/unit/test_kubernetes.py +++ b/test/unit/test_kubernetes.py @@ -2,10 +2,13 @@ from metaflow.plugins.kubernetes.kubernetes import ( KubernetesException, - validate_kube_labels, parse_kube_keyvalue_list, ) +from metaflow.plugins.kubernetes.kubernetes_decorator import ( + validate_kube_labels_or_annotations, +) + @pytest.mark.parametrize( "labels", @@ -40,8 +43,8 @@ }, ], ) -def test_kubernetes_decorator_validate_kube_labels(labels): - assert validate_kube_labels(labels) +def test_kubernetes_decorator_validate_kube_labels_or_annotations(labels): + assert validate_kube_labels_or_annotations(labels) @pytest.mark.parametrize( @@ -65,10 +68,10 @@ def test_kubernetes_decorator_validate_kube_labels(labels): {"valid": "test", "invalid": "bißchen"}, ], ) -def test_kubernetes_decorator_validate_kube_labels_fail(labels): +def test_kubernetes_decorator_validate_kube_labels_or_annotations_fail(labels): """Fail if label contains invalid characters or is too long""" with pytest.raises(KubernetesException): - validate_kube_labels(labels) + validate_kube_labels_or_annotations(labels) @pytest.mark.parametrize(