Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add labels and fix argo #1360

Merged
merged 11 commits into from
May 23, 2023
2 changes: 2 additions & 0 deletions metaflow/metaflow_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,8 @@
KUBERNETES_NODE_SELECTOR = from_conf("KUBERNETES_NODE_SELECTOR", "")
KUBERNETES_TOLERATIONS = from_conf("KUBERNETES_TOLERATIONS", "")
KUBERNETES_SECRETS = from_conf("KUBERNETES_SECRETS", "")
# Default labels for kubernetes pods
KUBERNETES_LABELS = from_conf("KUBERNETES_LABELS", "")
# Default GPU vendor to use by K8S jobs created by Metaflow (supports nvidia, amd)
KUBERNETES_GPU_VENDOR = from_conf("KUBERNETES_GPU_VENDOR", "nvidia")
# Default container image for K8S
Expand Down
21 changes: 17 additions & 4 deletions metaflow/plugins/argo/argo_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
DEFAULT_METADATA,
DEFAULT_SECRETS_BACKEND_TYPE,
KUBERNETES_FETCH_EC2_METADATA,
KUBERNETES_LABELS,
KUBERNETES_NAMESPACE,
KUBERNETES_NODE_SELECTOR,
KUBERNETES_SANDBOX_INIT_SCRIPT,
Expand All @@ -40,6 +41,10 @@
)
from metaflow.mflog import BASH_SAVE_LOGS, bash_capture_logs, export_mflog_env_vars
from metaflow.parameters import deploy_time_eval
from metaflow.plugins.kubernetes.kubernetes import (
parse_kube_keyvalue_list,
validate_kube_labels,
)
from metaflow.util import (
compress_list,
dict_to_cli_options,
Expand Down Expand Up @@ -469,6 +474,11 @@ def _compile_workflow_template(self):
}
)

# get labels from env vars
dhpollack marked this conversation as resolved.
Show resolved Hide resolved
env_labels = KUBERNETES_LABELS.split(",")
env_labels = parse_kube_keyvalue_list(env_labels, False)
validate_kube_labels(env_labels)

return (
WorkflowTemplate()
.metadata(
Expand Down Expand Up @@ -550,6 +560,7 @@ def _compile_workflow_template(self):
.label("app.kubernetes.io/name", "metaflow-task")
.label("app.kubernetes.io/part-of", "metaflow")
.annotations(annotations)
.labels(env_labels)
)
# Set the entrypoint to flow name
.entrypoint(self.flow.name)
Expand Down Expand Up @@ -1102,15 +1113,16 @@ def _container_templates(self):
.retry_strategy(
times=total_retries,
minutes_between_retries=minutes_between_retries,
)
.metadata(
).metadata(
ObjectMeta().annotation("metaflow/step_name", node.name)
# Unfortunately, we can't set the task_id since it is generated
# inside the pod. However, it can be inferred from the annotation
# set by argo-workflows - `workflows.argoproj.io/outputs` - refer
# the field 'task-id' in 'parameters'
# .annotation("metaflow/task_id", ...)
.annotation("metaflow/attempt", retry_count)
# Set labels
.labels(resources.get("labels"))
dhpollack marked this conversation as resolved.
Show resolved Hide resolved
)
# Set emptyDir volume for state management
.empty_dir_volume("out")
Expand All @@ -1122,6 +1134,7 @@ def _container_templates(self):
)
# Set node selectors
.node_selectors(resources.get("node_selector"))
# Set tolerations
.tolerations(resources.get("tolerations"))
# Set container
.container(
Expand Down Expand Up @@ -1624,7 +1637,7 @@ def label(self, key, value):
def labels(self, labels):
if "labels" not in self.payload:
self.payload["labels"] = {}
self.payload["labels"].update(labels)
self.payload["labels"].update(labels or {})
saikonen marked this conversation as resolved.
Show resolved Hide resolved
return self

def name(self, name):
Expand Down Expand Up @@ -1733,7 +1746,7 @@ def label(self, key, value):
def labels(self, labels):
if "labels" not in self.payload:
self.payload["labels"] = {}
self.payload["labels"].update(labels)
self.payload["labels"].update(labels or {})
return self

def labels_from(self, labels_from):
Expand Down
64 changes: 63 additions & 1 deletion metaflow/plugins/kubernetes/kubernetes.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import json
import math
import os
import re
import shlex
import time
from typing import Dict, List, Optional

from metaflow import current, util
from metaflow.exception import MetaflowException
Expand All @@ -20,8 +22,9 @@
DEFAULT_AWS_CLIENT_PROVIDER,
DEFAULT_METADATA,
DEFAULT_SECRETS_BACKEND_TYPE,
KUBERNETES_SANDBOX_INIT_SCRIPT,
KUBERNETES_FETCH_EC2_METADATA,
KUBERNETES_LABELS,
KUBERNETES_SANDBOX_INIT_SCRIPT,
S3_ENDPOINT_URL,
SERVICE_HEADERS,
SERVICE_INTERNAL_URL,
Expand Down Expand Up @@ -157,6 +160,7 @@ def create_job(
run_time_limit=None,
env=None,
tolerations=None,
labels=None,
):
if env is None:
env = {}
Expand Down Expand Up @@ -189,6 +193,7 @@ def create_job(
retries=0,
step_name=step_name,
tolerations=tolerations,
labels=self._get_labels(labels),
use_tmpfs=use_tmpfs,
tmpfs_tempdir=tmpfs_tempdir,
tmpfs_size=tmpfs_size,
Expand Down Expand Up @@ -374,3 +379,60 @@ def wait_for_launch(job):
"stderr",
job_id=self._job.id,
)

@staticmethod
def _get_labels(extra_labels=None):
if extra_labels is None:
extra_labels = {}
env_labels = KUBERNETES_LABELS.split(",")
env_labels = parse_kube_keyvalue_list(env_labels, False)
labels = {**env_labels, **extra_labels}
validate_kube_labels(labels)
return labels


def validate_kube_labels(
dhpollack marked this conversation as resolved.
Show resolved Hide resolved
labels: Optional[Dict[str, Optional[str]]],
) -> bool:
"""Validate label values.

This validates the kubernetes label values. It does not validate the keys.
Ideally, keys should be static and also the validation rules for keys are
more complex than those for values. For full validation rules, see:

https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set
"""

def validate_label(s: Optional[str]):
regex_match = r"^(([A-Za-z0-9][-A-Za-z0-9_.]{0,61})?[A-Za-z0-9])?$"
if not s:
# allow empty label
return True
if not re.search(regex_match, s):
raise KubernetesException(
'Invalid value: "%s"\n'
"A valid label must be an empty string or one that\n"
" - Consist of alphanumeric, '-', '_' or '.' characters\n"
" - Begins and ends with an alphanumeric character\n"
" - Is at most 63 characters" % s
)
return True

return all([validate_label(v) for v in labels.values()]) if labels else True


def parse_kube_keyvalue_list(items: List[str], requires_both: bool = True):
try:
ret = {}
for item_str in items:
item = item_str.split("=", 1)
if requires_both:
item[1] # raise IndexError
if str(item[0]) in ret:
raise KubernetesException("Duplicate key found: %s" % str(item[0]))
ret[str(item[0])] = str(item[1]) if len(item) > 1 else None
return ret
except KubernetesException as e:
raise e
except (AttributeError, IndexError):
raise KubernetesException("Unable to parse kubernetes list: %s" % items)
8 changes: 4 additions & 4 deletions metaflow/plugins/kubernetes/kubernetes_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
import time
import traceback

from metaflow import util, JSONTypeClass
from metaflow import JSONTypeClass, util
from metaflow._vendor import click
from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY, CommandException
from metaflow.metadata.util import sync_local_metadata_from_datastore
from metaflow.metaflow_config import DATASTORE_LOCAL_DIR
from metaflow.metaflow_config import DATASTORE_LOCAL_DIR, KUBERNETES_LABELS
from metaflow.mflog import TASK_LOG_SOURCE

from .kubernetes import Kubernetes, KubernetesKilledException
from .kubernetes import Kubernetes, KubernetesKilledException, parse_kube_keyvalue_list
from .kubernetes_decorator import KubernetesDecorator


Expand Down Expand Up @@ -185,7 +185,7 @@ def echo(msg, stream="stderr", job_id=None, **kwargs):
stderr_location = ds.get_log_location(TASK_LOG_SOURCE, "stderr")

# `node_selector` is a tuple of strings, convert it to a dictionary
node_selector = KubernetesDecorator.parse_node_selector(node_selector)
node_selector = parse_kube_keyvalue_list(node_selector)

def _sync_metadata():
if ctx.obj.metadata.TYPE == "local":
Expand Down
36 changes: 14 additions & 22 deletions metaflow/plugins/kubernetes/kubernetes_decorator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import hashlib
import json
import os
import platform
import re
import sys
from typing import Dict, List, Optional, Union

from metaflow import current
from metaflow.decorators import StepDecorator
Expand All @@ -12,21 +15,21 @@
DATASTORE_LOCAL_DIR,
KUBERNETES_CONTAINER_IMAGE,
KUBERNETES_CONTAINER_REGISTRY,
KUBERNETES_FETCH_EC2_METADATA,
KUBERNETES_GPU_VENDOR,
KUBERNETES_LABELS,
KUBERNETES_NAMESPACE,
KUBERNETES_NODE_SELECTOR,
KUBERNETES_TOLERATIONS,
KUBERNETES_SERVICE_ACCOUNT,
KUBERNETES_SECRETS,
KUBERNETES_FETCH_EC2_METADATA,
KUBERNETES_SERVICE_ACCOUNT,
KUBERNETES_TOLERATIONS,
)
from metaflow.plugins.resources_decorator import ResourcesDecorator
from metaflow.plugins.timeout_decorator import get_run_time_limit_for_task
from metaflow.sidecar import Sidecar

from ..aws.aws_utils import get_docker_registry, get_ec2_instance_metadata

from .kubernetes import KubernetesException
from .kubernetes import KubernetesException, parse_kube_keyvalue_list

try:
unicode
Expand Down Expand Up @@ -116,7 +119,7 @@ def __init__(self, attributes=None, statically_defined=False):
self.attributes["tolerations"] = json.loads(KUBERNETES_TOLERATIONS)

if isinstance(self.attributes["node_selector"], str):
self.attributes["node_selector"] = self.parse_node_selector(
self.attributes["node_selector"] = parse_kube_keyvalue_list(
self.attributes["node_selector"].split(",")
)

Expand Down Expand Up @@ -313,10 +316,11 @@ def runtime_step_cli(
for k, v in self.attributes.items():
if k == "namespace":
cli_args.command_options["k8s_namespace"] = v
elif k == "node_selector" and v:
cli_args.command_options[k] = ",".join(
["=".join([key, str(val)]) for key, val in v.items()]
)
elif k in {"node_selector"} and v:
cli_args.command_options[k] = [
"=".join([key, str(val)]) if val else key
for key, val in v.items()
]
elif k == "tolerations":
cli_args.command_options[k] = json.dumps(v)
else:
Expand Down Expand Up @@ -428,15 +432,3 @@ def _save_package_once(cls, flow_datastore, package):
cls.package_url, cls.package_sha = flow_datastore.save_data(
[package.blob], len_hint=1
)[0]

@staticmethod
def parse_node_selector(node_selector: list):
try:
return {
str(k.split("=", 1)[0]): str(k.split("=", 1)[1])
for k in node_selector or []
}
except (AttributeError, IndexError):
raise KubernetesException(
"Unable to parse node_selector: %s" % node_selector
)
97 changes: 97 additions & 0 deletions test/unit/test_kubernetes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import pytest

from metaflow.plugins.kubernetes.kubernetes import (
KubernetesException,
validate_kube_labels,
parse_kube_keyvalue_list,
)


@pytest.mark.parametrize(
"labels",
[
None,
{"label": "value"},
{"label1": "val1", "label2": "val2"},
{"label1": "val1", "label2": None},
{"label": "a"},
{"label": ""},
{
"label": (
"1234567890"
"1234567890"
"1234567890"
"1234567890"
"1234567890"
"1234567890"
"123"
)
},
{
"label": (
"1234567890"
"1234567890"
"1234-_.890"
"1234567890"
"1234567890"
"1234567890"
"123"
)
},
],
)
def test_kubernetes_decorator_validate_kube_labels(labels):
assert validate_kube_labels(labels)


@pytest.mark.parametrize(
"labels",
[
{"label": "a-"},
{"label": ".a"},
{"label": "test()"},
{
"label": (
"1234567890"
"1234567890"
"1234567890"
"1234567890"
"1234567890"
"1234567890"
"1234"
)
},
{"label": "(){}??"},
{"valid": "test", "invalid": "bißchen"},
],
)
def test_kubernetes_decorator_validate_kube_labels_fail(labels):
"""Fail if label contains invalid characters or is too long"""
with pytest.raises(KubernetesException):
validate_kube_labels(labels)


@pytest.mark.parametrize(
"items,requires_both,expected",
[
(["key=value"], True, {"key": "value"}),
(["key=value"], False, {"key": "value"}),
(["key"], False, {"key": None}),
(["key=value", "key2=value2"], True, {"key": "value", "key2": "value2"}),
],
)
def test_kubernetes_parse_keyvalue_list(items, requires_both, expected):
ret = parse_kube_keyvalue_list(items, requires_both)
assert ret == expected


@pytest.mark.parametrize(
"items,requires_both",
[
(["key=value", "key=value2"], True),
(["key"], True),
],
)
def test_kubernetes_parse_keyvalue_list(items, requires_both):
with pytest.raises(KubernetesException):
parse_kube_keyvalue_list(items, requires_both)