Skip to content

Commit

Permalink
Adds custom annotations via env variables (#1442)
Browse files Browse the repository at this point in the history
* test annotations

* add configurable annotations

* correct function name

* make extra annotations update existing ones

* correct logic for annotation flow

* Add custom annotations for argo workflows

* remove unused label/annotation parameters, clean up functions

* add annotations to sensors

* undo reformatting of _get_annotations

* rework kubernetes labels support. Also support labels as a dict through decorator now

* rework label precedence

* label precedence fix for sensors

* rename kube validator

* refactor Kubernetes annotations to the kube decorator

* clean up setting annotations in argo workflows

---------

Co-authored-by: tylerpotts <tyler.potts@dtn.com>
Co-authored-by: Sakari Ikonen <sakari.a.ikonen@gmail.com>
  • Loading branch information
3 people authored Jul 27, 2023
1 parent ec8fd6c commit f8861b8
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 123 deletions.
2 changes: 2 additions & 0 deletions metaflow/metaflow_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,8 @@
KUBERNETES_SECRETS = from_conf("KUBERNETES_SECRETS", "")
# Default labels for kubernetes pods
KUBERNETES_LABELS = from_conf("KUBERNETES_LABELS", "")
# Default annotations for kubernetes pods
KUBERNETES_ANNOTATIONS = from_conf("KUBERNETES_ANNOTATIONS", "")
# Default GPU vendor to use by K8S jobs created by Metaflow (supports nvidia, amd)
KUBERNETES_GPU_VENDOR = from_conf("KUBERNETES_GPU_VENDOR", "nvidia")
# Default container image for K8S
Expand Down
110 changes: 55 additions & 55 deletions metaflow/plugins/argo/argo_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
DEFAULT_METADATA,
DEFAULT_SECRETS_BACKEND_TYPE,
KUBERNETES_FETCH_EC2_METADATA,
KUBERNETES_LABELS,
KUBERNETES_NAMESPACE,
KUBERNETES_NODE_SELECTOR,
KUBERNETES_SANDBOX_INIT_SCRIPT,
Expand All @@ -44,10 +43,6 @@
)
from metaflow.mflog import BASH_SAVE_LOGS, bash_capture_logs, export_mflog_env_vars
from metaflow.parameters import deploy_time_eval
from metaflow.plugins.kubernetes.kubernetes import (
parse_kube_keyvalue_list,
validate_kube_labels,
)
from metaflow.util import (
compress_list,
dict_to_cli_options,
Expand Down Expand Up @@ -157,6 +152,7 @@ def __init__(
self._schedule, self._timezone = self._get_schedule()

self.kubernetes_labels = self._get_kubernetes_labels()
self.kubernetes_annotations = self._get_kubernetes_annotations()
self._workflow_template = self._compile_workflow_template()
self._sensor = self._compile_sensor()

Expand Down Expand Up @@ -266,18 +262,49 @@ def trigger(cls, name, parameters=None):
except Exception as e:
raise ArgoWorkflowsException(str(e))

@staticmethod
def _get_kubernetes_labels():
def _get_kubernetes_labels(self):
"""
Get Kubernetes labels from environment variable.
Parses the string into a dict and validates that values adhere to Kubernetes restrictions.
Get Kubernetes labels from the start step decorator.
"""
if not KUBERNETES_LABELS:
return {}
env_labels = KUBERNETES_LABELS.split(",")
env_labels = parse_kube_keyvalue_list(env_labels, False)
validate_kube_labels(env_labels)
return env_labels

resources = dict(
[
deco
for node in self.graph
if node.name == "start"
for deco in node.decorators
if deco.name == "kubernetes"
][0].attributes
)
return resources["labels"] or {}

def _get_kubernetes_annotations(self):
"""
Get Kubernetes annotations from the start step decorator. Append Argo specific annotations.
"""
resources = dict(
[
deco
for node in self.graph
if node.name == "start"
for deco in node.decorators
if deco.name == "kubernetes"
][0].attributes
)
annotations = {}
if resources["annotations"] is not None:
# make a copy so we do not mess possible start-step specific annotations.
annotations = resources["annotations"].copy()

annotations.update(
{
"metaflow/production_token": self.production_token,
"metaflow/owner": self.username,
"metaflow/user": "argo-workflows",
"metaflow/flow_name": self.flow.name,
}
)
return annotations

def _get_schedule(self):
schedule = self.flow._flow_decorators.get("schedule")
Expand Down Expand Up @@ -568,21 +595,6 @@ def _compile_workflow_template(self):
# generate container templates at the top level (in WorkflowSpec) and maintain
# references to them within the DAGTask.

annotations = {
"metaflow/production_token": self.production_token,
"metaflow/owner": self.username,
"metaflow/user": "argo-workflows",
"metaflow/flow_name": self.flow.name,
}
if current.get("project_name"):
annotations.update(
{
"metaflow/project_name": current.project_name,
"metaflow/branch_name": current.branch_name,
"metaflow/project_flow_name": current.project_flow_name,
}
)

return (
WorkflowTemplate()
.metadata(
Expand All @@ -595,7 +607,7 @@ def _compile_workflow_template(self):
.namespace(KUBERNETES_NAMESPACE)
.label("app.kubernetes.io/name", "metaflow-flow")
.label("app.kubernetes.io/part-of", "metaflow")
.annotations(annotations)
.annotations(self.kubernetes_annotations)
)
.spec(
WorkflowSpec()
Expand Down Expand Up @@ -628,7 +640,10 @@ def _compile_workflow_template(self):
.label("app.kubernetes.io/name", "metaflow-run")
.label("app.kubernetes.io/part-of", "metaflow")
.annotations(
{**annotations, **{"metaflow/run_id": "argo-{{workflow.name}}"}}
{
**self.kubernetes_annotations,
**{"metaflow/run_id": "argo-{{workflow.name}}"},
}
)
# TODO: Set dynamic labels using labels_from. Ideally, we would
# want to expose run_id as a label. It's easy to add labels,
Expand Down Expand Up @@ -661,10 +676,10 @@ def _compile_workflow_template(self):
# Set common pod metadata.
.pod_metadata(
Metadata()
.labels(self.kubernetes_labels)
.label("app.kubernetes.io/name", "metaflow-task")
.label("app.kubernetes.io/part-of", "metaflow")
.annotations(annotations)
.labels(self.kubernetes_labels)
.annotations(self.kubernetes_annotations)
)
# Set the entrypoint to flow name
.entrypoint(self.flow.name)
Expand Down Expand Up @@ -1290,13 +1305,15 @@ def _container_templates(self):
minutes_between_retries=minutes_between_retries,
)
.metadata(
ObjectMeta().annotation("metaflow/step_name", node.name)
ObjectMeta()
.annotation("metaflow/step_name", node.name)
# Unfortunately, we can't set the task_id since it is generated
# inside the pod. However, it can be inferred from the annotation
# set by argo-workflows - `workflows.argoproj.io/outputs` - refer
# the field 'task-id' in 'parameters'
# .annotation("metaflow/task_id", ...)
.annotation("metaflow/attempt", retry_count)
.labels(resources["labels"])
)
# Set emptyDir volume for state management
.empty_dir_volume("out")
Expand Down Expand Up @@ -1659,34 +1676,17 @@ def _compile_sensor(self):
"sdk (https://pypi.org/project/kubernetes/) first."
)

labels = {"app.kubernetes.io/part-of": "metaflow"}

annotations = {
"metaflow/production_token": self.production_token,
"metaflow/owner": self.username,
"metaflow/user": "argo-workflows",
"metaflow/flow_name": self.flow.name,
}
if current.get("project_name"):
annotations.update(
{
"metaflow/project_name": current.project_name,
"metaflow/branch_name": current.branch_name,
"metaflow/project_flow_name": current.project_flow_name,
}
)

return (
Sensor()
.metadata(
# Sensor metadata.
ObjectMeta()
.name(self.name.replace(".", "-"))
.namespace(KUBERNETES_NAMESPACE)
.labels(self.kubernetes_labels)
.label("app.kubernetes.io/name", "metaflow-sensor")
.label("app.kubernetes.io/part-of", "metaflow")
.labels(self.kubernetes_labels)
.annotations(annotations)
.annotations(self.kubernetes_annotations)
)
.spec(
SensorSpec().template(
Expand All @@ -1696,7 +1696,7 @@ def _compile_sensor(self):
ObjectMeta()
.label("app.kubernetes.io/name", "metaflow-sensor")
.label("app.kubernetes.io/part-of", "metaflow")
.annotations(annotations)
.annotations(self.kubernetes_annotations)
)
.container(
# Run sensor in guaranteed QoS. The sensor isn't doing a lot
Expand Down
69 changes: 8 additions & 61 deletions metaflow/plugins/kubernetes/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
DEFAULT_METADATA,
DEFAULT_SECRETS_BACKEND_TYPE,
KUBERNETES_FETCH_EC2_METADATA,
KUBERNETES_LABELS,
KUBERNETES_SANDBOX_INIT_SCRIPT,
S3_ENDPOINT_URL,
SERVICE_HEADERS,
Expand Down Expand Up @@ -168,6 +167,7 @@ def create_job(
persistent_volume_claims=None,
tolerations=None,
labels=None,
annotations=None,
):
if env is None:
env = {}
Expand Down Expand Up @@ -201,7 +201,8 @@ def create_job(
retries=0,
step_name=step_name,
tolerations=tolerations,
labels=self._get_labels(labels),
labels=labels,
annotations=annotations,
use_tmpfs=use_tmpfs,
tmpfs_tempdir=tmpfs_tempdir,
tmpfs_size=tmpfs_size,
Expand Down Expand Up @@ -286,31 +287,17 @@ def create_job(
for name, value in env.items():
job.environment_variable(name, value)

# Add job specific annotations not set in the decorator.
annotations = {
"metaflow/user": user,
"metaflow/flow_name": flow_name,
"metaflow/run_id": run_id,
"metaflow/step_name": step_name,
"metaflow/task_id": task_id,
"metaflow/attempt": attempt,
}
if current.get("project_name"):
annotations.update(
{
"metaflow/project_name": current.project_name,
"metaflow/branch_name": current.branch_name,
"metaflow/project_flow_name": current.project_flow_name,
}
)

for name, value in annotations.items():
job.annotation(name, value)

(
job.annotation("metaflow/run_id", run_id)
.annotation("metaflow/step_name", step_name)
.annotation("metaflow/task_id", task_id)
.annotation("metaflow/attempt", attempt)
.label("app.kubernetes.io/name", "metaflow-task")
.label("app.kubernetes.io/part-of", "metaflow")
)

return job.create()

def wait(self, stdout_location, stderr_location, echo=None):
Expand Down Expand Up @@ -406,46 +393,6 @@ def wait_for_launch(job):
job_id=self._job.id,
)

@staticmethod
def _get_labels(extra_labels=None):
if extra_labels is None:
extra_labels = {}
env_labels = KUBERNETES_LABELS.split(",") if KUBERNETES_LABELS else []
env_labels = parse_kube_keyvalue_list(env_labels, False)
labels = {**env_labels, **extra_labels}
validate_kube_labels(labels)
return labels


def validate_kube_labels(
labels: Optional[Dict[str, Optional[str]]],
) -> bool:
"""Validate label values.
This validates the kubernetes label values. It does not validate the keys.
Ideally, keys should be static and also the validation rules for keys are
more complex than those for values. For full validation rules, see:
https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set
"""

def validate_label(s: Optional[str]):
regex_match = r"^(([A-Za-z0-9][-A-Za-z0-9_.]{0,61})?[A-Za-z0-9])?$"
if not s:
# allow empty label
return True
if not re.search(regex_match, s):
raise KubernetesException(
'Invalid value: "%s"\n'
"A valid label must be an empty string or one that\n"
" - Consist of alphanumeric, '-', '_' or '.' characters\n"
" - Begins and ends with an alphanumeric character\n"
" - Is at most 63 characters" % s
)
return True

return all([validate_label(v) for v in labels.values()]) if labels else True


def parse_kube_keyvalue_list(items: List[str], requires_both: bool = True):
try:
Expand Down
18 changes: 17 additions & 1 deletion metaflow/plugins/kubernetes/kubernetes_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from metaflow._vendor import click
from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY, CommandException
from metaflow.metadata.util import sync_local_metadata_from_datastore
from metaflow.metaflow_config import DATASTORE_LOCAL_DIR, KUBERNETES_LABELS
from metaflow.metaflow_config import DATASTORE_LOCAL_DIR
from metaflow.mflog import TASK_LOG_SOURCE

from .kubernetes import Kubernetes, KubernetesKilledException, parse_kube_keyvalue_list
Expand Down Expand Up @@ -105,6 +105,18 @@ def kubernetes():
type=JSONTypeClass(),
multiple=False,
)
@click.option(
"--labels",
default=None,
type=JSONTypeClass(),
multiple=False,
)
@click.option(
"--annotations",
default=None,
type=JSONTypeClass(),
multiple=False,
)
@click.pass_context
def step(
ctx,
Expand All @@ -130,6 +142,8 @@ def step(
run_time_limit=None,
persistent_volume_claims=None,
tolerations=None,
labels=None,
annotations=None,
**kwargs
):
def echo(msg, stream="stderr", job_id=None, **kwargs):
Expand Down Expand Up @@ -244,6 +258,8 @@ def _sync_metadata():
env=env,
persistent_volume_claims=persistent_volume_claims,
tolerations=tolerations,
labels=labels,
annotations=annotations,
)
except Exception as e:
traceback.print_exc(chain=False)
Expand Down
Loading

0 comments on commit f8861b8

Please sign in to comment.