Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Kubernetes tolerations #1207

Merged
merged 22 commits into from
Dec 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions metaflow/metaflow_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@
KUBERNETES_SERVICE_ACCOUNT = from_conf("KUBERNETES_SERVICE_ACCOUNT")
# Default node selectors to use by K8S jobs created by Metaflow - foo=bar,baz=bab
KUBERNETES_NODE_SELECTOR = from_conf("KUBERNETES_NODE_SELECTOR", "")
KUBERNETES_TOLERATIONS = from_conf("KUBERNETES_TOLERATIONS", "")
KUBERNETES_SECRETS = from_conf("KUBERNETES_SECRETS", "")
# Default GPU vendor to use by K8S jobs created by Metaflow (supports nvidia, amd)
KUBERNETES_GPU_VENDOR = from_conf("KUBERNETES_GPU_VENDOR", "nvidia")
Expand Down
37 changes: 19 additions & 18 deletions metaflow/plugins/argo/argo_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,9 @@ def _sanitize(name):
return name.replace("_", "-")

@classmethod
def trigger(cls, name, parameters={}):
def trigger(cls, name, parameters=None):
if parameters is None:
savingoyal marked this conversation as resolved.
Show resolved Hide resolved
parameters = {}
try:
workflow_template = ArgoClient(
namespace=KUBERNETES_NAMESPACE
Expand Down Expand Up @@ -378,14 +380,18 @@ def _compile(self):

# Visit every node and yield the uber DAGTemplate(s).
def _dag_templates(self):
def _visit(node, exit_node=None, templates=[], dag_tasks=[]):
def _visit(node, exit_node=None, templates=None, dag_tasks=None):
# Every for-each node results in a separate subDAG and an equivalent
# DAGTemplate rooted at the child of the for-each node. Each DAGTemplate
# has a unique name - the top-level DAGTemplate is named as the name of
# the flow and the subDAG DAGTemplates are named after the (only) descendant
# of the for-each node.

# Emit if we have reached the end of the sub workflow
if dag_tasks is None:
dag_tasks = []
if templates is None:
templates = []
if exit_node is not None and exit_node is node.name:
return templates, dag_tasks

Expand Down Expand Up @@ -859,7 +865,8 @@ def _container_templates(self):
.retry_strategy(
times=total_retries,
minutes_between_retries=minutes_between_retries,
).metadata(
)
.metadata(
ObjectMeta().annotation("metaflow/step_name", node.name)
# Unfortunately, we can't set the task_id since it is generated
# inside the pod. However, it can be inferred from the annotation
Expand All @@ -871,19 +878,8 @@ def _container_templates(self):
# Set emptyDir volume for state management
.empty_dir_volume("out")
# Set node selectors
.node_selectors(
{
str(k.split("=", 1)[0]): str(k.split("=", 1)[1])
for k in (
resources.get("node_selector")
or (
KUBERNETES_NODE_SELECTOR.split(",")
if KUBERNETES_NODE_SELECTOR
else []
)
)
}
)
.node_selectors(resources.get("node_selector"))
.tolerations(resources.get("tolerations"))
# Set container
.container(
# TODO: Unify the logic with kubernetes.py
Expand Down Expand Up @@ -1234,8 +1230,13 @@ def empty_dir_volume(self, name):

def node_selectors(self, node_selectors):
    """Merge *node_selectors* (a mapping of node label -> value) into the
    pod spec payload.

    Repeated calls accumulate selectors. A falsy argument (None or {})
    only ensures the empty ``"nodeSelector"`` mapping exists.

    Returns self for fluent chaining.
    """
    # Initialize under the correct key: writing to "labels" here would
    # leave "nodeSelector" missing and make the update below raise KeyError.
    if "nodeSelector" not in self.payload:
        self.payload["nodeSelector"] = {}
    if node_selectors:
        self.payload["nodeSelector"].update(node_selectors)
    return self

def tolerations(self, tolerations):
    """Set the pod's tolerations (a list of toleration dicts, e.g.
    ``[{"key": ..., "operator": ..., "value": ...}]``).

    Skips the assignment for a falsy value so the rendered manifest does
    not carry an explicit ``"tolerations": null`` entry — mirroring the
    guard used by ``node_selectors``.

    Returns self for fluent chaining.
    """
    if tolerations:
        self.payload["tolerations"] = tolerations
    return self

def to_json(self):
Expand Down
7 changes: 6 additions & 1 deletion metaflow/plugins/kubernetes/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,13 @@ def create_job(
disk=None,
memory=None,
run_time_limit=None,
env={},
env=None,
tolerations=None,
):

if env is None:
env = {}

job = (
KubernetesClient()
.job(
Expand Down Expand Up @@ -177,6 +181,7 @@ def create_job(
# Retries are handled by Metaflow runtime
retries=0,
step_name=step_name,
tolerations=tolerations,
)
.environment_variable("METAFLOW_CODE_SHA", code_package_sha)
.environment_variable("METAFLOW_CODE_URL", code_package_url)
Expand Down
14 changes: 13 additions & 1 deletion metaflow/plugins/kubernetes/kubernetes_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
import time
import traceback

from metaflow import util
from metaflow import util, JSONTypeClass
from metaflow._vendor import click
from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY, CommandException
from metaflow.metadata.util import sync_local_metadata_from_datastore
from metaflow.metaflow_config import DATASTORE_LOCAL_DIR
from metaflow.mflog import TASK_LOG_SOURCE

from .kubernetes import Kubernetes, KubernetesKilledException
from .kubernetes_decorator import KubernetesDecorator


@click.group()
Expand Down Expand Up @@ -84,6 +85,12 @@ def kubernetes():
default=5 * 24 * 60 * 60, # Default is set to 5 days
help="Run time limit in seconds for Kubernetes pod.",
)
@click.option(
"--tolerations",
default=None,
type=JSONTypeClass(),
multiple=False,
odracci marked this conversation as resolved.
Show resolved Hide resolved
)
@click.pass_context
def step(
ctx,
Expand All @@ -102,6 +109,7 @@ def step(
gpu=None,
gpu_vendor=None,
run_time_limit=None,
tolerations=None,
**kwargs
):
def echo(msg, stream="stderr", job_id=None):
Expand Down Expand Up @@ -166,6 +174,9 @@ def echo(msg, stream="stderr", job_id=None):
stdout_location = ds.get_log_location(TASK_LOG_SOURCE, "stdout")
stderr_location = ds.get_log_location(TASK_LOG_SOURCE, "stderr")

# `node_selector` is a tuple of strings, convert it to a dictionary
node_selector = KubernetesDecorator.parse_node_selector(node_selector)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this needed anymore?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it is. At this stage, node_selector is a tuple of strings. kubernetes.launch_job expects a dictionary

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason not to handle this parsing within kubernetes_job - the actual format is dictated by the kubernetes SDK and that's why currently all the Kubernetes-related formatting is happening within the KubernetesJob object. As the SDK evolves, any changes would be isolated to that object.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

node_selector contains the value generated by

@click.option(
    "--node-selector",
    multiple=True,
    default=None,
    help="NodeSelector for Kubernetes pod.",
)

which is a tuple of strings

('key=val', 'foo=bar')

kubernetes_job expects a dictionary like

{
  "key": "val",
  "foo": "bar",
}

parse_node_selector converts the tuple of strings to a dictionary compatible with the Kubernetes SDK.
Does it make sense?


def _sync_metadata():
if ctx.obj.metadata.TYPE == "local":
sync_local_metadata_from_datastore(
Expand Down Expand Up @@ -206,6 +217,7 @@ def _sync_metadata():
gpu_vendor=gpu_vendor,
run_time_limit=run_time_limit,
env=env,
tolerations=tolerations,
)
except Exception as e:
traceback.print_exc(chain=False)
Expand Down
64 changes: 59 additions & 5 deletions metaflow/plugins/kubernetes/kubernetes_decorator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import json
import os
import platform
import sys

import requests

from metaflow import util
from metaflow.decorators import StepDecorator
from metaflow.exception import MetaflowException
from metaflow.metadata import MetaDatum
Expand All @@ -16,6 +14,7 @@
KUBERNETES_GPU_VENDOR,
KUBERNETES_NAMESPACE,
KUBERNETES_NODE_SELECTOR,
KUBERNETES_TOLERATIONS,
KUBERNETES_SERVICE_ACCOUNT,
KUBERNETES_SECRETS,
)
Expand All @@ -24,6 +23,7 @@
from metaflow.sidecar import Sidecar

from ..aws.aws_utils import get_docker_registry

from .kubernetes import KubernetesException

try:
Expand Down Expand Up @@ -65,6 +65,10 @@ class KubernetesDecorator(StepDecorator):
Kubernetes secrets to use when launching pod in Kubernetes. These
secrets are in addition to the ones defined in `METAFLOW_KUBERNETES_SECRETS`
in Metaflow configuration.
tolerations : List[Dict[str, str]]
Kubernetes tolerations to use when launching pod in Kubernetes. If
not specified, the value of `METAFLOW_KUBERNETES_TOLERATIONS` is used
from Metaflow configuration.
"""

name = "kubernetes"
Expand All @@ -79,6 +83,8 @@ class KubernetesDecorator(StepDecorator):
"namespace": None,
"gpu": None, # value of 0 implies that the scheduled node should not have GPUs
"gpu_vendor": None,
"tolerations": None, # e.g., [{"key": "arch", "operator": "Equal", "value": "amd"},
# {"key": "foo", "operator": "Equal", "value": "bar"}]
}
package_url = None
package_sha = None
Expand All @@ -93,9 +99,39 @@ def __init__(self, attributes=None, statically_defined=False):
self.attributes["service_account"] = KUBERNETES_SERVICE_ACCOUNT
if not self.attributes["gpu_vendor"]:
self.attributes["gpu_vendor"] = KUBERNETES_GPU_VENDOR
if not self.attributes["node_selector"] and KUBERNETES_NODE_SELECTOR:
self.attributes["node_selector"] = KUBERNETES_NODE_SELECTOR
if not self.attributes["tolerations"] and KUBERNETES_TOLERATIONS:
self.attributes["tolerations"] = json.loads(KUBERNETES_TOLERATIONS)

if isinstance(self.attributes["node_selector"], str):
self.attributes["node_selector"] = self.parse_node_selector(
self.attributes["node_selector"].split(",")
)

# TODO: Handle node_selector in a better manner. Currently it is special
# cased in kubernetes_client.py
if self.attributes["tolerations"]:
try:
from kubernetes.client import V1Toleration
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand that the rationale for including this check in _init_ is to ensure that this check is invoked for argo-workflows too. However, this check will fail if the user hasn't installed the python package kubernetes yet - which is checked in package_init - that check should technically happen before the check for tolerations.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is checked in kubernetes_cli.
The idea is that the check is invoked only if required, which means the python package kubernetes must be installed. I think the order of the execution doesn't matter.
If kubernetes is not available, self.attributes["tolerations"] is not being used, then the check is not required.
Does it make sense to you?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When the user pip installs metaflow, we don't install Kubernetes python package. It's only when the user starts executing a flow that involves @kubernetes or argo - we throw a nice warning asking them to install the python package. Now, if that first flow has tolerations defined, then the user will instead get an error saying no module named Kubernetes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That import is inside a try block with

except (NameError, ImportError):
  pass

It should not raise any errors related to the missing module, is it correct?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes - but it's the round about way this check is implemented which is my concern. We can ship this and come back to clean it up.


for toleration in self.attributes["tolerations"]:
try:
invalid_keys = [
k
for k in toleration.keys()
if k not in V1Toleration.attribute_map.keys()
]
if len(invalid_keys) > 0:
raise KubernetesException(
"Tolerations parameter contains invalid keys: %s"
% invalid_keys
)
except AttributeError:
raise KubernetesException(
"Unable to parse tolerations: %s"
% self.attributes["tolerations"]
)
except (NameError, ImportError):
pass

# If no docker image is explicitly specified, impute a default image.
if not self.attributes["image"]:
Expand Down Expand Up @@ -248,6 +284,12 @@ def runtime_step_cli(
for k, v in self.attributes.items():
if k == "namespace":
cli_args.command_options["k8s_namespace"] = v
elif k == "node_selector" and v:
cli_args.command_options[k] = ",".join(
["=".join([key, str(val)]) for key, val in v.items()]
)
elif k == "tolerations":
cli_args.command_options[k] = json.dumps(v)
else:
cli_args.command_options[k] = v
cli_args.command_options["run-time-limit"] = self.run_time_limit
Expand Down Expand Up @@ -340,3 +382,15 @@ def _save_package_once(cls, flow_datastore, package):
cls.package_url, cls.package_sha = flow_datastore.save_data(
[package.blob], len_hint=1
)[0]

@staticmethod
def parse_node_selector(node_selector: list):
try:
return {
str(k.split("=", 1)[0]): str(k.split("=", 1)[1])
for k in node_selector or []
}
except (AttributeError, IndexError):
raise KubernetesException(
"Unable to parse node_selector: %s" % node_selector
)
23 changes: 7 additions & 16 deletions metaflow/plugins/kubernetes/kubernetes_job.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import json
import math
import os
import random
import sys
import time

from metaflow.exception import MetaflowException
from metaflow.metaflow_config import KUBERNETES_NODE_SELECTOR, KUBERNETES_SECRETS
from metaflow.metaflow_config import KUBERNETES_SECRETS

CLIENT_REFRESH_INTERVAL_SECONDS = 300

Expand Down Expand Up @@ -71,6 +69,7 @@ def create(self):
# Note: This implementation ensures that there is only one unique Pod
# (unique UID) per Metaflow task attempt.
client = self._client.get()

self._job = client.V1Job(
api_version="batch/v1",
kind="Job",
Expand Down Expand Up @@ -163,17 +162,7 @@ def create(self):
),
)
],
node_selector={
str(k.split("=", 1)[0]): str(k.split("=", 1)[1])
for k in (
self._kwargs.get("node_selector")
or (
KUBERNETES_NODE_SELECTOR.split(",")
if KUBERNETES_NODE_SELECTOR
else []
)
)
},
node_selector=self._kwargs.get("node_selector"),
# TODO (savin): Support image_pull_secrets
# image_pull_secrets=?,
# TODO (savin): Support preemption policies
Expand All @@ -188,8 +177,10 @@ def create(self):
service_account_name=self._kwargs["service_account"],
# Terminate the container immediately on SIGTERM
termination_grace_period_seconds=0,
# TODO (savin): Enable tolerations for GPU scheduling.
# tolerations=?,
tolerations=[
client.V1Toleration(**toleration)
for toleration in self._kwargs.get("tolerations") or []
],
# volumes=?,
# TODO (savin): Set termination_message_policy
),
Expand Down