diff --git a/metaflow/metaflow_config.py b/metaflow/metaflow_config.py index 52cedd89f54..2fe570b0895 100644 --- a/metaflow/metaflow_config.py +++ b/metaflow/metaflow_config.py @@ -322,6 +322,9 @@ ARGO_WORKFLOWS_KUBERNETES_SECRETS = from_conf("ARGO_WORKFLOWS_KUBERNETES_SECRETS", "") ARGO_WORKFLOWS_ENV_VARS_TO_SKIP = from_conf("ARGO_WORKFLOWS_ENV_VARS_TO_SKIP", "") +KUBERNETES_JOBSET_GROUP = from_conf("KUBERNETES_JOBSET_GROUP", "jobset.x-k8s.io") +KUBERNETES_JOBSET_VERSION = from_conf("KUBERNETES_JOBSET_VERSION", "v1alpha2") + ## # Argo Events Configuration ## diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index 572ca205af0..5cd0313fdda 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -838,6 +838,11 @@ def _dag_templates(self): def _visit( node, exit_node=None, templates=None, dag_tasks=None, parent_foreach=None ): + if node.parallel_foreach: + raise ArgoWorkflowsException( + "Deploying flows with @parallel decorator(s) " + "as Argo Workflows is not supported currently." + ) # Every for-each node results in a separate subDAG and an equivalent # DAGTemplate rooted at the child of the for-each node. Each DAGTemplate # has a unique name - the top-level DAGTemplate is named as the name of diff --git a/metaflow/plugins/kubernetes/kubernetes.py b/metaflow/plugins/kubernetes/kubernetes.py index c87d3c221de..c6bfe38e9ca 100644 --- a/metaflow/plugins/kubernetes/kubernetes.py +++ b/metaflow/plugins/kubernetes/kubernetes.py @@ -3,9 +3,9 @@ import os import re import shlex +import copy import time from typing import Dict, List, Optional -import uuid from uuid import uuid4 from metaflow import current, util @@ -66,6 +66,12 @@ class KubernetesKilledException(MetaflowException): headline = "Kubernetes Batch job killed" +def _extract_labels_and_annotations_from_job_spec(job_spec): + annotations = job_spec.template.metadata.annotations + labels = job_spec.template.metadata.labels + return copy.copy(annotations), copy.copy(labels) + + class Kubernetes(object): def __init__( self, @@ -140,9 +146,64 @@ def _command( return shlex.split('bash -c "%s"' % cmd_str) def launch_job(self, **kwargs): - self._job = self.create_job(**kwargs).execute() + if ( + "num_parallel" in kwargs + and kwargs["num_parallel"] + and int(kwargs["num_parallel"]) > 0 + ): + job = self.create_job_object(**kwargs) + spec = job.create_job_spec() + # `kwargs["step_cli"]` is setting `ubf_context` as control to ALL pods. + # This will be modified by the KubernetesJobSet object + annotations, labels = _extract_labels_and_annotations_from_job_spec(spec) + self._job = self.create_jobset( + job_spec=spec, + run_id=kwargs["run_id"], + step_name=kwargs["step_name"], + task_id=kwargs["task_id"], + namespace=kwargs["namespace"], + env=kwargs["env"], + num_parallel=kwargs["num_parallel"], + port=kwargs["port"], + annotations=annotations, + labels=labels, + ).execute() + else: + kwargs["name_pattern"] = "t-{uid}-".format(uid=str(uuid4())[:8]) + self._job = self.create_job_object(**kwargs).k8screate().execute() + + def create_jobset( + self, + job_spec=None, + run_id=None, + step_name=None, + task_id=None, + namespace=None, + env=None, + num_parallel=None, + port=None, + annotations=None, + labels=None, + ): + if env is None: + env = {} - def create_job( + _prefix = str(uuid4())[:6] + js = KubernetesClient().jobset( + name="js-%s" % _prefix, + run_id=run_id, + task_id=task_id, + step_name=step_name, + namespace=namespace, + labels=self._get_labels(labels), + annotations=annotations, + num_parallel=num_parallel, + job_spec=job_spec, + port=port, + ) + return js + + def create_job_object( self, flow_name, run_id, @@ -176,14 +237,15 @@ def create_job( labels=None, shared_memory=None, port=None, + name_pattern=None, + num_parallel=None, ): if env is None: env = {} - job = ( KubernetesClient() .job( - generate_name="t-{uid}-".format(uid=str(uuid4())[:8]), + generate_name=name_pattern, namespace=namespace, service_account=service_account, secrets=secrets, @@ -217,6 +279,7 @@ def create_job( persistent_volume_claims=persistent_volume_claims, shared_memory=shared_memory, port=port, + num_parallel=num_parallel, ) .environment_variable("METAFLOW_CODE_SHA", code_package_sha) .environment_variable("METAFLOW_CODE_URL", code_package_url) @@ -332,6 +395,9 @@ def create_job( .label("app.kubernetes.io/part-of", "metaflow") ) + return job + + def create_k8sjob(self, job): return job.create() def wait(self, stdout_location, stderr_location, echo=None): @@ -366,7 +432,7 @@ def wait_for_launch(job): t = time.time() time.sleep(update_delay(time.time() - start_time)) - prefix = b"[%s] " % util.to_bytes(self._job.id) + _make_prefix = lambda: b"[%s] " % util.to_bytes(self._job.id) stdout_tail = get_log_tailer(stdout_location, self._datastore.TYPE) stderr_tail = get_log_tailer(stderr_location, self._datastore.TYPE) @@ -376,7 +442,7 @@ def wait_for_launch(job): # 2) Tail logs until the job has finished tail_logs( - prefix=prefix, + prefix=_make_prefix(), stdout_tail=stdout_tail, stderr_tail=stderr_tail, echo=echo, @@ -392,7 +458,6 @@ def wait_for_launch(job): # exists prior to calling S3Tail and note the user about # truncated logs if it doesn't. # TODO : For hard crashes, we can fetch logs from the pod. - if self._job.has_failed: exit_code, reason = self._job.reason msg = next( diff --git a/metaflow/plugins/kubernetes/kubernetes_cli.py b/metaflow/plugins/kubernetes/kubernetes_cli.py index 9d4750f45f6..3c32d4c4dd0 100644 --- a/metaflow/plugins/kubernetes/kubernetes_cli.py +++ b/metaflow/plugins/kubernetes/kubernetes_cli.py @@ -7,6 +7,7 @@ from metaflow._vendor import click from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY, CommandException from metaflow.metadata.util import sync_local_metadata_from_datastore +from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK from metaflow.metaflow_config import DATASTORE_LOCAL_DIR, KUBERNETES_LABELS from metaflow.mflog import TASK_LOG_SOURCE import metaflow.tracing as tracing @@ -109,6 +110,15 @@ def kubernetes(): ) @click.option("--shared-memory", default=None, help="Size of shared memory in MiB") @click.option("--port", default=None, help="Port number to expose from the container") +@click.option( + "--ubf-context", default=None, type=click.Choice([None, UBF_CONTROL, UBF_TASK]) +) +@click.option( + "--num-parallel", + default=None, + type=int, + help="Number of parallel nodes to run as a multi-node job.", +) @click.pass_context def step( ctx, @@ -136,6 +146,7 @@ def step( tolerations=None, shared_memory=None, port=None, + num_parallel=None, **kwargs ): def echo(msg, stream="stderr", job_id=None, **kwargs): @@ -251,6 +262,7 @@ def _sync_metadata(): tolerations=tolerations, shared_memory=shared_memory, port=port, + num_parallel=num_parallel, ) except Exception as e: traceback.print_exc(chain=False) diff --git a/metaflow/plugins/kubernetes/kubernetes_client.py b/metaflow/plugins/kubernetes/kubernetes_client.py index 33023f36c11..631d2ecdf13 100644 --- a/metaflow/plugins/kubernetes/kubernetes_client.py +++ b/metaflow/plugins/kubernetes/kubernetes_client.py @@ -4,7 +4,7 @@ from metaflow.exception import MetaflowException -from .kubernetes_job import KubernetesJob +from .kubernetes_job import KubernetesJob, KubernetesJobSet CLIENT_REFRESH_INTERVAL_SECONDS = 300 @@ -61,5 +61,8 @@ def get(self): return self._client + def jobset(self, **kwargs): + return KubernetesJobSet(self, **kwargs) + def job(self, **kwargs): return KubernetesJob(self, **kwargs) diff --git a/metaflow/plugins/kubernetes/kubernetes_decorator.py b/metaflow/plugins/kubernetes/kubernetes_decorator.py index b6253cb5841..f68c53c74d8 100644 --- a/metaflow/plugins/kubernetes/kubernetes_decorator.py +++ b/metaflow/plugins/kubernetes/kubernetes_decorator.py @@ -32,6 +32,8 @@ from ..aws.aws_utils import get_docker_registry, get_ec2_instance_metadata from .kubernetes import KubernetesException, parse_kube_keyvalue_list +from metaflow.unbounded_foreach import UBF_CONTROL +from .kubernetes_jobsets import TaskIdConstructor try: unicode @@ -239,12 +241,6 @@ def step_init(self, flow, graph, step, decos, environment, flow_datastore, logge "Kubernetes. Please use one or the other.".format(step=step) ) - for deco in decos: - if getattr(deco, "IS_PARALLEL", False): - raise KubernetesException( - "@kubernetes does not support parallel execution currently." - ) - # Set run time limit for the Kubernetes job. self.run_time_limit = get_run_time_limit_for_task(decos) if self.run_time_limit < 60: @@ -453,6 +449,24 @@ def task_pre_step( self._save_logs_sidecar = Sidecar("save_logs_periodically") self._save_logs_sidecar.start() + num_parallel = None + if hasattr(flow, "_parallel_ubf_iter"): + num_parallel = flow._parallel_ubf_iter.num_parallel + + if num_parallel and num_parallel >= 1 and ubf_context == UBF_CONTROL: + control_task_id, worker_task_ids = TaskIdConstructor.join_step_task_ids( + num_parallel + ) + mapper_task_ids = [control_task_id] + worker_task_ids + flow._control_mapper_tasks = [ + "%s/%s/%s" % (run_id, step_name, mapper_task_id) + for mapper_task_id in mapper_task_ids + ] + flow._control_task_is_mapper_zero = True + + if num_parallel and num_parallel > 1: + _setup_multinode_environment() + def task_finished( self, step_name, flow, graph, is_task_ok, retry_count, max_retries ): @@ -486,3 +500,20 @@ def _save_package_once(cls, flow_datastore, package): cls.package_url, cls.package_sha = flow_datastore.save_data( [package.blob], len_hint=1 )[0] + + +def _setup_multinode_environment(): + import socket + + os.environ["MF_PARALLEL_MAIN_IP"] = socket.gethostbyname(os.environ["MASTER_ADDR"]) + os.environ["MF_PARALLEL_NUM_NODES"] = os.environ["WORLD_SIZE"] + if os.environ.get("CONTROL_INDEX") is not None: + os.environ["MF_PARALLEL_NODE_INDEX"] = str(0) + elif os.environ.get("WORKER_REPLICA_INDEX") is not None: + os.environ["MF_PARALLEL_NODE_INDEX"] = str( + int(os.environ["WORKER_REPLICA_INDEX"]) + 1 + ) + else: + raise MetaflowException( + "Jobset related ENV vars called $CONTROL_INDEX or $WORKER_REPLICA_INDEX not found" + ) diff --git a/metaflow/plugins/kubernetes/kubernetes_job.py b/metaflow/plugins/kubernetes/kubernetes_job.py index adb9446e5f9..0ea3cf822e3 100644 --- a/metaflow/plugins/kubernetes/kubernetes_job.py +++ b/metaflow/plugins/kubernetes/kubernetes_job.py @@ -2,14 +2,17 @@ import math import random import time - +import copy +import sys from metaflow.tracing import inject_tracing_vars - - from metaflow.exception import MetaflowException from metaflow.metaflow_config import KUBERNETES_SECRETS +from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK CLIENT_REFRESH_INTERVAL_SECONDS = 300 +from .kubernetes_jobsets import ( + KubernetesJobSet, # We need this import for Kubernetes Client. +) class KubernetesJobException(MetaflowException): @@ -58,7 +61,206 @@ def __init__(self, client, **kwargs): self._client = client self._kwargs = kwargs - def create(self): + def create_job_spec(self): + client = self._client.get() + + # tmpfs variables + use_tmpfs = self._kwargs["use_tmpfs"] + tmpfs_size = self._kwargs["tmpfs_size"] + tmpfs_enabled = use_tmpfs or (tmpfs_size and not use_tmpfs) + shared_memory = ( + int(self._kwargs["shared_memory"]) + if self._kwargs["shared_memory"] + else None + ) + return client.V1JobSpec( + # Retries are handled by Metaflow when it is responsible for + # executing the flow. The responsibility is moved to Kubernetes + # when Argo Workflows is responsible for the execution. + backoff_limit=self._kwargs.get("retries", 0), + completions=self._kwargs.get("completions", 1), + ttl_seconds_after_finished=7 + * 60 + * 60 # Remove job after a week. TODO: Make this configurable + * 24, + template=client.V1PodTemplateSpec( + metadata=client.V1ObjectMeta( + annotations=self._kwargs.get("annotations", {}), + labels=self._kwargs.get("labels", {}), + namespace=self._kwargs["namespace"], + ), + spec=client.V1PodSpec( + # Timeout is set on the pod and not the job (important!) + active_deadline_seconds=self._kwargs["timeout_in_seconds"], + # TODO (savin): Enable affinities for GPU scheduling. + # affinity=?, + containers=[ + client.V1Container( + command=self._kwargs["command"], + ports=[] + if self._kwargs["port"] is None + else [ + client.V1ContainerPort( + container_port=int(self._kwargs["port"]) + ) + ], + env=[ + client.V1EnvVar(name=k, value=str(v)) + for k, v in self._kwargs.get( + "environment_variables", {} + ).items() + ] + # And some downward API magic. Add (key, value) + # pairs below to make pod metadata available + # within Kubernetes container. + + [ + client.V1EnvVar( + name=k, + value_from=client.V1EnvVarSource( + field_ref=client.V1ObjectFieldSelector( + field_path=str(v) + ) + ), + ) + for k, v in { + "METAFLOW_KUBERNETES_POD_NAMESPACE": "metadata.namespace", + "METAFLOW_KUBERNETES_POD_NAME": "metadata.name", + "METAFLOW_KUBERNETES_POD_ID": "metadata.uid", + "METAFLOW_KUBERNETES_SERVICE_ACCOUNT_NAME": "spec.serviceAccountName", + "METAFLOW_KUBERNETES_NODE_IP": "status.hostIP", + }.items() + ] + + [ + client.V1EnvVar(name=k, value=str(v)) + for k, v in inject_tracing_vars({}).items() + ], + env_from=[ + client.V1EnvFromSource( + secret_ref=client.V1SecretEnvSource( + name=str(k), + # optional=True + ) + ) + for k in list(self._kwargs.get("secrets", [])) + + KUBERNETES_SECRETS.split(",") + if k + ], + image=self._kwargs["image"], + image_pull_policy=self._kwargs["image_pull_policy"], + name=self._kwargs["step_name"].replace("_", "-"), + resources=client.V1ResourceRequirements( + requests={ + "cpu": str(self._kwargs["cpu"]), + "memory": "%sM" % str(self._kwargs["memory"]), + "ephemeral-storage": "%sM" + % str(self._kwargs["disk"]), + }, + limits={ + "%s.com/gpu".lower() + % self._kwargs["gpu_vendor"]: str( + self._kwargs["gpu"] + ) + for k in [0] + # Don't set GPU limits if gpu isn't specified. + if self._kwargs["gpu"] is not None + }, + ), + volume_mounts=( + [ + client.V1VolumeMount( + mount_path=self._kwargs.get("tmpfs_path"), + name="tmpfs-ephemeral-volume", + ) + ] + if tmpfs_enabled + else [] + ) + + ( + [ + client.V1VolumeMount( + mount_path="/dev/shm", name="dhsm" + ) + ] + if shared_memory + else [] + ) + + ( + [ + client.V1VolumeMount(mount_path=path, name=claim) + for claim, path in self._kwargs[ + "persistent_volume_claims" + ].items() + ] + if self._kwargs["persistent_volume_claims"] is not None + else [] + ), + ) + ], + node_selector=self._kwargs.get("node_selector"), + # TODO (savin): Support image_pull_secrets + # image_pull_secrets=?, + # TODO (savin): Support preemption policies + # preemption_policy=?, + # + # A Container in a Pod may fail for a number of + # reasons, such as because the process in it exited + # with a non-zero exit code, or the Container was + # killed due to OOM etc. If this happens, fail the pod + # and let Metaflow handle the retries. + restart_policy="Never", + service_account_name=self._kwargs["service_account"], + # Terminate the container immediately on SIGTERM + termination_grace_period_seconds=0, + tolerations=[ + client.V1Toleration(**toleration) + for toleration in self._kwargs.get("tolerations") or [] + ], + volumes=( + [ + client.V1Volume( + name="tmpfs-ephemeral-volume", + empty_dir=client.V1EmptyDirVolumeSource( + medium="Memory", + # Add default unit as ours differs from Kubernetes default. + size_limit="{}Mi".format(tmpfs_size), + ), + ) + ] + if tmpfs_enabled + else [] + ) + + ( + [ + client.V1Volume( + name="dhsm", + empty_dir=client.V1EmptyDirVolumeSource( + medium="Memory", + size_limit="{}Mi".format(shared_memory), + ), + ) + ] + if shared_memory + else [] + ) + + ( + [ + client.V1Volume( + name=claim, + persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource( + claim_name=claim + ), + ) + for claim in self._kwargs["persistent_volume_claims"].keys() + ] + if self._kwargs["persistent_volume_claims"] is not None + else [] + ), + # TODO (savin): Set termination_message_policy + ), + ), + ) + + def k8screate(self): # A discerning eye would notice and question the choice of using the # V1Job construct over the V1Pod construct given that we don't rely much # on any of the V1Job semantics. The major reasons at the moment are - @@ -77,11 +279,6 @@ def create(self): use_tmpfs = self._kwargs["use_tmpfs"] tmpfs_size = self._kwargs["tmpfs_size"] tmpfs_enabled = use_tmpfs or (tmpfs_size and not use_tmpfs) - shared_memory = ( - int(self._kwargs["shared_memory"]) - if self._kwargs["shared_memory"] - else None - ) self._job = client.V1Job( api_version="batch/v1", @@ -94,197 +291,7 @@ def create(self): generate_name=self._kwargs["generate_name"], namespace=self._kwargs["namespace"], # Defaults to `default` ), - spec=client.V1JobSpec( - # Retries are handled by Metaflow when it is responsible for - # executing the flow. The responsibility is moved to Kubernetes - # when Argo Workflows is responsible for the execution. - backoff_limit=self._kwargs.get("retries", 0), - completions=1, # A single non-indexed pod job - ttl_seconds_after_finished=7 - * 60 - * 60 # Remove job after a week. TODO: Make this configurable - * 24, - template=client.V1PodTemplateSpec( - metadata=client.V1ObjectMeta( - annotations=self._kwargs.get("annotations", {}), - labels=self._kwargs.get("labels", {}), - namespace=self._kwargs["namespace"], - ), - spec=client.V1PodSpec( - # Timeout is set on the pod and not the job (important!) - active_deadline_seconds=self._kwargs["timeout_in_seconds"], - # TODO (savin): Enable affinities for GPU scheduling. - # affinity=?, - containers=[ - client.V1Container( - command=self._kwargs["command"], - ports=[ - client.V1ContainerPort( - container_port=int(self._kwargs["port"]) - ) - ] - if "port" in self._kwargs and self._kwargs["port"] - else None, - env=[ - client.V1EnvVar(name=k, value=str(v)) - for k, v in self._kwargs.get( - "environment_variables", {} - ).items() - ] - # And some downward API magic. Add (key, value) - # pairs below to make pod metadata available - # within Kubernetes container. - + [ - client.V1EnvVar( - name=k, - value_from=client.V1EnvVarSource( - field_ref=client.V1ObjectFieldSelector( - field_path=str(v) - ) - ), - ) - for k, v in { - "METAFLOW_KUBERNETES_POD_NAMESPACE": "metadata.namespace", - "METAFLOW_KUBERNETES_POD_NAME": "metadata.name", - "METAFLOW_KUBERNETES_POD_ID": "metadata.uid", - "METAFLOW_KUBERNETES_SERVICE_ACCOUNT_NAME": "spec.serviceAccountName", - "METAFLOW_KUBERNETES_NODE_IP": "status.hostIP", - }.items() - ] - + [ - client.V1EnvVar(name=k, value=str(v)) - for k, v in inject_tracing_vars({}).items() - ], - env_from=[ - client.V1EnvFromSource( - secret_ref=client.V1SecretEnvSource( - name=str(k), - # optional=True - ) - ) - for k in list(self._kwargs.get("secrets", [])) - + KUBERNETES_SECRETS.split(",") - if k - ], - image=self._kwargs["image"], - image_pull_policy=self._kwargs["image_pull_policy"], - name=self._kwargs["step_name"].replace("_", "-"), - resources=client.V1ResourceRequirements( - requests={ - "cpu": str(self._kwargs["cpu"]), - "memory": "%sM" % str(self._kwargs["memory"]), - "ephemeral-storage": "%sM" - % str(self._kwargs["disk"]), - }, - limits={ - "%s.com/gpu".lower() - % self._kwargs["gpu_vendor"]: str( - self._kwargs["gpu"] - ) - for k in [0] - # Don't set GPU limits if gpu isn't specified. - if self._kwargs["gpu"] is not None - }, - ), - volume_mounts=( - [ - client.V1VolumeMount( - mount_path=self._kwargs.get("tmpfs_path"), - name="tmpfs-ephemeral-volume", - ) - ] - if tmpfs_enabled - else [] - ) - + ( - [ - client.V1VolumeMount( - mount_path="/dev/shm", name="dhsm" - ) - ] - if shared_memory - else [] - ) - + ( - [ - client.V1VolumeMount( - mount_path=path, name=claim - ) - for claim, path in self._kwargs[ - "persistent_volume_claims" - ].items() - ] - if self._kwargs["persistent_volume_claims"] - is not None - else [] - ), - ) - ], - node_selector=self._kwargs.get("node_selector"), - # TODO (savin): Support image_pull_secrets - # image_pull_secrets=?, - # TODO (savin): Support preemption policies - # preemption_policy=?, - # - # A Container in a Pod may fail for a number of - # reasons, such as because the process in it exited - # with a non-zero exit code, or the Container was - # killed due to OOM etc. If this happens, fail the pod - # and let Metaflow handle the retries. - restart_policy="Never", - service_account_name=self._kwargs["service_account"], - # Terminate the container immediately on SIGTERM - termination_grace_period_seconds=0, - tolerations=[ - client.V1Toleration(**toleration) - for toleration in self._kwargs.get("tolerations") or [] - ], - volumes=( - [ - client.V1Volume( - name="tmpfs-ephemeral-volume", - empty_dir=client.V1EmptyDirVolumeSource( - medium="Memory", - # Add default unit as ours differs from Kubernetes default. - size_limit="{}Mi".format(tmpfs_size), - ), - ) - ] - if tmpfs_enabled - else [] - ) - + ( - [ - client.V1Volume( - name="dhsm", - empty_dir=client.V1EmptyDirVolumeSource( - medium="Memory", - size_limit="{}Mi".format(shared_memory), - ), - ) - ] - if shared_memory - else [] - ) - + ( - [ - client.V1Volume( - name=claim, - persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource( - claim_name=claim - ), - ) - for claim in self._kwargs[ - "persistent_volume_claims" - ].keys() - ] - if self._kwargs["persistent_volume_claims"] is not None - else [] - ), - # TODO (savin): Set termination_message_policy - ), - ), - ), + spec=self.create_job_spec(), ) return self @@ -418,7 +425,7 @@ def __init__(self, client, name, uid, namespace): def best_effort_kill(): try: self.kill() - except: + except Exception as ex: pass atexit.register(best_effort_kill) @@ -482,9 +489,9 @@ def kill(self): # 3. If the pod object hasn't shown up yet, we set the parallelism to 0 # to preempt it. client = self._client.get() + if not self.is_done: if self.is_running: - # Case 1. from kubernetes.stream import stream diff --git a/metaflow/plugins/kubernetes/kubernetes_jobsets.py b/metaflow/plugins/kubernetes/kubernetes_jobsets.py new file mode 100644 index 00000000000..7b4ea247ac3 --- /dev/null +++ b/metaflow/plugins/kubernetes/kubernetes_jobsets.py @@ -0,0 +1,753 @@ +import copy +import math +import random +import time +from metaflow.metaflow_current import current +from metaflow.exception import MetaflowException +from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK +import json +from metaflow.metaflow_config import KUBERNETES_JOBSET_GROUP, KUBERNETES_JOBSET_VERSION +from collections import namedtuple + + +class KubernetesJobsetException(MetaflowException): + headline = "Kubernetes jobset error" + + +# TODO [DUPLICATE CODE]: Refactor this method to a separate file so that +# It can be used by both KubernetesJob and KubernetesJobset +def k8s_retry(deadline_seconds=60, max_backoff=32): + def decorator(function): + from functools import wraps + + @wraps(function) + def wrapper(*args, **kwargs): + from kubernetes import client + + deadline = time.time() + deadline_seconds + retry_number = 0 + + while True: + try: + result = function(*args, **kwargs) + return result + except client.rest.ApiException as e: + if e.status == 500: + current_t = time.time() + backoff_delay = min( + math.pow(2, retry_number) + random.random(), max_backoff + ) + if current_t + backoff_delay < deadline: + time.sleep(backoff_delay) + retry_number += 1 + continue # retry again + else: + raise + else: + raise + + return wrapper + + return decorator + + +JobsetStatus = namedtuple( + "JobsetStatus", + [ + "control_pod_failed", # boolean + "control_exit_code", + "control_pod_status", # string like ():() [used for user-messaging] + "control_started", + "control_completed", + "worker_pods_failed", + "workers_are_suspended", + "workers_have_started", + "all_jobs_are_suspended", + "jobset_finished", + "jobset_failed", + "status_unknown", + "jobset_was_terminated", + "some_jobs_are_running", + ], +) + + +def _basic_validation_for_js(jobset): + if not jobset.get("status") or not jobset.get("status").get("replicatedJobsStatus"): + return False + worker_jobs = [ + w for w in jobset.get("spec").get("replicatedJobs") if w["name"] == "worker" + ] + if len(worker_jobs) == 0: + raise KubernetesJobsetException("No worker jobs found in the jobset manifest") + control_job = [ + w for w in jobset.get("spec").get("replicatedJobs") if w["name"] == "control" + ] + if len(control_job) == 0: + raise KubernetesJobsetException("No control job found in the jobset manifest") + return True + + +def _derive_pod_status_and_status_code(control_pod): + overall_status = None + control_exit_code = None + control_pod_failed = False + if control_pod: + container_status = None + pod_status = control_pod.get("status", {}).get("phase") + container_statuses = control_pod.get("status", {}).get("containerStatuses") + if container_statuses is None: + container_status = ": ".join( + filter( + None, + [ + control_pod.get("status", {}).get("reason"), + control_pod.get("status", {}).get("message"), + ], + ) + ) + else: + for k, v in container_statuses[0].get("state", {}).items(): + if v is not None: + control_exit_code = v.get("exit_code") + container_status = ": ".join( + filter( + None, + [v.get("reason"), v.get("message")], + ) + ) + if container_status is None: + overall_status = ( + f"pod status: {pod_status} | container status: {container_status}" + ) + else: + overall_status = f"pod status: {pod_status}" + if pod_status == "Failed": + control_pod_failed = True + return overall_status, control_exit_code, control_pod_failed + + +def _construct_jobset_logical_status(jobset, control_pod=None): + if not _basic_validation_for_js(jobset): + return JobsetStatus( + control_started=False, + control_completed=False, + workers_are_suspended=False, + workers_have_started=False, + all_jobs_are_suspended=False, + jobset_finished=False, + jobset_failed=False, + status_unknown=True, + jobset_was_terminated=False, + control_exit_code=None, + control_pod_status=None, + worker_pods_failed=False, + control_pod_failed=False, + some_jobs_are_running=False, + ) + + js_status = jobset.get("status") + + control_started = False + control_completed = False + workers_are_suspended = False + workers_have_started = False + all_jobs_are_suspended = jobset.get("spec", {}).get("suspend", False) + jobset_finished = False + jobset_failed = False + status_unknown = False + jobset_was_terminated = False + worker_pods_failed = False + some_jobs_are_running = False + + total_worker_jobs = [ + w["replicas"] + for w in jobset.get("spec").get("replicatedJobs", []) + if w["name"] == "worker" + ][0] + total_control_jobs = [ + w["replicas"] + for w in jobset.get("spec").get("replicatedJobs", []) + if w["name"] == "control" + ][0] + + if total_worker_jobs == 0 and total_control_jobs == 0: + jobset_was_terminated = True + + replicated_job_statuses = js_status.get("replicatedJobsStatus") + for job_status in replicated_job_statuses: + if job_status["active"] > 0: + some_jobs_are_running = True + + if job_status["name"] == "control": + control_started = job_status["active"] > 0 or job_status["succeeded"] > 0 + control_completed = job_status["succeeded"] > 0 + if job_status["failed"] > 0: + jobset_failed = True + + if job_status["name"] == "worker": + workers_have_started = job_status["active"] == total_worker_jobs + workers_are_suspended = job_status["suspended"] > 0 + if job_status["failed"] > 0: + worker_pods_failed = True + jobset_failed = True + + if js_status.get("conditions"): + for condition in js_status["conditions"]: + if condition["type"] == "Completed": + jobset_finished = True + if condition["type"] == "Failed": + jobset_failed = True + + ( + overall_status, + control_exit_code, + control_pod_failed, + ) = _derive_pod_status_and_status_code(control_pod) + + return JobsetStatus( + control_started=control_started, + control_completed=control_completed, + workers_are_suspended=workers_are_suspended, + workers_have_started=workers_have_started, + all_jobs_are_suspended=all_jobs_are_suspended, + jobset_finished=jobset_finished, + jobset_failed=jobset_failed, + status_unknown=status_unknown, + jobset_was_terminated=jobset_was_terminated, + control_exit_code=control_exit_code, + control_pod_status=overall_status, + worker_pods_failed=worker_pods_failed, + control_pod_failed=control_pod_failed, + some_jobs_are_running=some_jobs_are_running, + ) + + +class RunningJobSet(object): + def __init__(self, client, name, namespace, group, version): + self._client = client + self._name = name + self._pod_name = None + self._namespace = namespace + self._group = group + self._version = version + self._pod = self._fetch_pod() + self._jobset = self._fetch_jobset() + + import atexit + + def best_effort_kill(): + try: + self.kill() + except Exception as ex: + pass + + atexit.register(best_effort_kill) + + def __repr__(self): + return "{}('{}/{}')".format( + self.__class__.__name__, self._namespace, self._name + ) + + @k8s_retry() + def _fetch_jobset( + self, + ): + # name : name of jobset. + # namespace : namespace of the jobset + # Query the jobset and return the object's status field as a JSON object + client = self._client.get() + with client.ApiClient() as api_client: + api_instance = client.CustomObjectsApi(api_client) + try: + jobset = api_instance.get_namespaced_custom_object( + group=self._group, + version=self._version, + namespace=self._namespace, + plural="jobsets", + name=self._name, + ) + return jobset + except client.rest.ApiException as e: + if e.status == 404: + raise KubernetesJobsetException( + "Unable to locate Kubernetes jobset %s" % self._name + ) + raise + + @k8s_retry() + def _fetch_pod(self): + # Fetch pod metadata. + client = self._client.get() + pods = ( + client.CoreV1Api() + .list_namespaced_pod( + namespace=self._namespace, + label_selector="jobset.sigs.k8s.io/jobset-name={}".format(self._name), + ) + .to_dict()["items"] + ) + if pods: + for pod in pods: + # check the labels of the pod to see if + # the `jobset.sigs.k8s.io/replicatedjob-name` is set to `control` + if ( + pod["metadata"]["labels"].get( + "jobset.sigs.k8s.io/replicatedjob-name" + ) + == "control" + ): + return pod + return {} + + def kill(self): + plural = "jobsets" + client = self._client.get() + # Get the jobset + with client.ApiClient() as api_client: + api_instance = client.CustomObjectsApi(api_client) + try: + jobset = api_instance.get_namespaced_custom_object( + group=self._group, + version=self._version, + namespace=self._namespace, + plural="jobsets", + name=self._name, + ) + + # Suspend the jobset and set the replica's to Zero. + # + jobset["spec"]["suspend"] = True + for replicated_job in jobset["spec"]["replicatedJobs"]: + replicated_job["replicas"] = 0 + + api_instance.replace_namespaced_custom_object( + group=self._group, + version=self._version, + namespace=self._namespace, + plural=plural, + name=jobset["metadata"]["name"], + body=jobset, + ) + except Exception as e: + raise KubernetesJobsetException( + "Exception when suspending existing jobset: %s\n" % e + ) + + @property + def id(self): + if self._pod_name: + return "pod %s" % self._pod_name + if self._pod: + self._pod_name = self._pod["metadata"]["name"] + return self.id + return "jobset %s" % self._name + + @property + def is_done(self): + def done(): + return ( + self._jobset_is_completed + or self._jobset_has_failed + or self._jobset_was_terminated + ) + + if not done(): + # If not done, fetch newer status + self._jobset = self._fetch_jobset() + self._pod = self._fetch_pod() + return done() + + @property + def status(self): + if self.is_done: + return "Jobset is done" + + status = _construct_jobset_logical_status(self._jobset, control_pod=self._pod) + if status.status_unknown: + return "Jobset status is unknown" + if status.control_started: + if status.control_pod_status: + return f"Jobset is running: {status.control_pod_status}" + return "Jobset is running" + if status.all_jobs_are_suspended: + return "Jobset is waiting to be unsuspended" + + return "Jobset waiting for jobs to start" + + @property + def has_succeeded(self): + return self.is_done and self._jobset_is_completed + + @property + def has_failed(self): + return self.is_done and self._jobset_has_failed + + @property + def is_running(self): + if self.is_done: + return False + status = _construct_jobset_logical_status(self._jobset, control_pod=self._pod) + if status.some_jobs_are_running: + return True + return False + + @property + def _jobset_was_terminated(self): + return _construct_jobset_logical_status( + self._jobset, control_pod=self._pod + ).jobset_was_terminated + + @property + def is_waiting(self): + return not self.is_done and not self.is_running + + @property + def reason(self): + # return exit code and reason + if self.is_done and not self.has_succeeded: + self._pod = self._fetch_pod() + elif self.has_succeeded: + return 0, None + status = _construct_jobset_logical_status(self._jobset, control_pod=self._pod) + if status.control_pod_failed: + return ( + status.control_exit_code, + "control-pod failed [%s]" % status.control_pod_status, + ) + elif status.worker_pods_failed: + return None, "Worker pods failed" + return None, None + + @property + def _jobset_is_completed(self): + return _construct_jobset_logical_status( + self._jobset, control_pod=self._pod + ).jobset_finished + + @property + def _jobset_has_failed(self): + return _construct_jobset_logical_status( + self._jobset, control_pod=self._pod + ).jobset_failed + + +class TaskIdConstructor: + @classmethod + def jobset_worker_id(cls, control_task_id: str): + return "".join( + [control_task_id.replace("control", "worker"), "-", "$WORKER_REPLICA_INDEX"] + ) + + @classmethod + def join_step_task_ids(cls, num_parallel): + """ + Called within the step decorator to set the `flow._control_mapper_tasks`. + Setting these allows the flow to know which tasks are needed in the join step. + We set this in the `task_pre_step` method of the decorator. + """ + control_task_id = current.task_id + worker_task_id_base = control_task_id.replace("control", "worker") + mapper = lambda idx: worker_task_id_base + f"-{idx}" + return control_task_id, [mapper(idx) for idx in range(0, num_parallel - 1)] + + @classmethod + def argo(cls): + pass + + +def _jobset_specific_env_vars(client, jobset_main_addr, master_port, num_parallel): + return [ + client.V1EnvVar( + name="MASTER_ADDR", + value=jobset_main_addr, + ), + client.V1EnvVar( + name="MASTER_PORT", + value=str(master_port), + ), + client.V1EnvVar( + name="WORLD_SIZE", + value=str(num_parallel), + ), + ] + [ + client.V1EnvVar( + name="JOBSET_RESTART_ATTEMPT", + value_from=client.V1EnvVarSource( + field_ref=client.V1ObjectFieldSelector( + field_path="metadata.annotations['jobset.sigs.k8s.io/restart-attempt']" + ) + ), + ), + client.V1EnvVar( + name="WORKER_REPLICA_INDEX", + value_from=client.V1EnvVarSource( + field_ref=client.V1ObjectFieldSelector( + field_path="metadata.annotations['jobset.sigs.k8s.io/job-index']" + ) + ), + ), + ] + + +def get_control_job( + client, + job_spec, + jobset_main_addr, + subdomain, + port=None, + num_parallel=None, + namespace=None, +) -> dict: + master_port = port + + job_spec = copy.deepcopy(job_spec) + job_spec.parallelism = 1 + job_spec.completions = 1 + job_spec.template.spec.set_hostname_as_fqdn = True + job_spec.template.spec.subdomain = subdomain + for idx in range(len(job_spec.template.spec.containers[0].command)): + # CHECK FOR THE ubf_context in the command. + # Replace the UBF context to the one appropriately matching control/worker. + # Since we are passing the `step_cli` one time from the top level to one + # KuberentesJobSet, we need to ensure that UBF context is replaced properly + # in all the worker jobs. + if UBF_CONTROL in job_spec.template.spec.containers[0].command[idx]: + job_spec.template.spec.containers[0].command[idx] = ( + job_spec.template.spec.containers[0] + .command[idx] + .replace(UBF_CONTROL, UBF_CONTROL + " " + "--split-index 0") + ) + + job_spec.template.spec.containers[0].env = ( + job_spec.template.spec.containers[0].env + + _jobset_specific_env_vars(client, jobset_main_addr, master_port, num_parallel) + + [ + client.V1EnvVar( + name="CONTROL_INDEX", + value=str(0), + ) + ] + ) + + # Based on https://github.com/kubernetes-sigs/jobset/blob/v0.5.0/api/jobset/v1alpha2/jobset_types.go#L178 + return dict( + name="control", + template=client.api_client.ApiClient().sanitize_for_serialization( + client.V1JobTemplateSpec( + metadata=client.V1ObjectMeta( + namespace=namespace, + # We don't set any annotations here + # since they have been either set in the JobSpec + # or on the JobSet level + ), + spec=job_spec, + ) + ), + replicas=1, # The control job will always have 1 replica. + ) + + +def get_worker_job( + client, + job_spec, + job_name, + jobset_main_addr, + subdomain, + control_task_id=None, + worker_task_id=None, + replicas=1, + port=None, + num_parallel=None, + namespace=None, +) -> dict: + master_port = port + + job_spec = copy.deepcopy(job_spec) + job_spec.parallelism = 1 + job_spec.completions = 1 + job_spec.template.spec.set_hostname_as_fqdn = True + job_spec.template.spec.subdomain = subdomain + + for idx in range(len(job_spec.template.spec.containers[0].command)): + if control_task_id in job_spec.template.spec.containers[0].command[idx]: + job_spec.template.spec.containers[0].command[idx] = ( + job_spec.template.spec.containers[0] + .command[idx] + .replace(control_task_id, worker_task_id) + ) + # CHECK FOR THE ubf_context in the command. + # Replace the UBF context to the one appropriately matching control/worker. + # Since we are passing the `step_cli` one time from the top level to one + # KuberentesJobSet, we need to ensure that UBF context is replaced properly + # in all the worker jobs. + if UBF_CONTROL in job_spec.template.spec.containers[0].command[idx]: + # Since all command will have a UBF_CONTROL, we need to replace the UBF_CONTROL + # with the actual UBF Context and also ensure that we are setting the correct + # split-index for the worker jobs. + split_index_str = "--split-index `expr $[WORKER_REPLICA_INDEX] + 1`" # This set in the environment variables below + job_spec.template.spec.containers[0].command[idx] = ( + job_spec.template.spec.containers[0] + .command[idx] + .replace(UBF_CONTROL, UBF_TASK + " " + split_index_str) + ) + + job_spec.template.spec.containers[0].env = job_spec.template.spec.containers[ + 0 + ].env + _jobset_specific_env_vars( + client, jobset_main_addr, master_port, num_parallel + ) + + # Based on https://github.com/kubernetes-sigs/jobset/blob/v0.5.0/api/jobset/v1alpha2/jobset_types.go#L178 + return dict( + name=job_name, + template=client.api_client.ApiClient().sanitize_for_serialization( + client.V1JobTemplateSpec( + metadata=client.V1ObjectMeta( + namespace=namespace, + # We don't set any annotations here + # since they have been either set in the JobSpec + # or on the JobSet level + ), + spec=job_spec, + ) + ), + replicas=replicas, + ) + + +def _make_domain_name( + jobset_name, main_job_name, main_job_index, main_pod_index, namespace +): + return "%s-%s-%s-%s.%s.%s.svc.cluster.local" % ( + jobset_name, + main_job_name, + main_job_index, + main_pod_index, + jobset_name, + namespace, + ) + + +class KubernetesJobSet(object): + def __init__( + self, + client, + name=None, + job_spec=None, + namespace=None, + num_parallel=None, + annotations=None, + labels=None, + port=None, + task_id=None, + **kwargs + ): + self._client = client + self._kwargs = kwargs + self._group = KUBERNETES_JOBSET_GROUP + self._version = KUBERNETES_JOBSET_VERSION + self.name = name + + main_job_name = "control" + main_job_index = 0 + main_pod_index = 0 + subdomain = self.name + num_parallel = int(1 if not num_parallel else num_parallel) + self._namespace = namespace + jobset_main_addr = _make_domain_name( + self.name, + main_job_name, + main_job_index, + main_pod_index, + self._namespace, + ) + + annotations = {} if not annotations else annotations + labels = {} if not labels else labels + + if "metaflow/task_id" in annotations: + del annotations["metaflow/task_id"] + + control_job = get_control_job( + client=self._client.get(), + job_spec=job_spec, + jobset_main_addr=jobset_main_addr, + subdomain=subdomain, + port=port, + num_parallel=num_parallel, + namespace=namespace, + ) + worker_task_id = TaskIdConstructor.jobset_worker_id(task_id) + worker_job = get_worker_job( + client=self._client.get(), + job_spec=job_spec, + job_name="worker", + jobset_main_addr=jobset_main_addr, + subdomain=subdomain, + control_task_id=task_id, + worker_task_id=worker_task_id, + replicas=num_parallel - 1, + port=port, + num_parallel=num_parallel, + namespace=namespace, + ) + worker_jobs = [worker_job] + # Based on https://github.com/kubernetes-sigs/jobset/blob/v0.5.0/api/jobset/v1alpha2/jobset_types.go#L163 + _kclient = client.get() + self._jobset = dict( + apiVersion=self._group + "/" + self._version, + kind="JobSet", + metadata=_kclient.api_client.ApiClient().sanitize_for_serialization( + _kclient.V1ObjectMeta( + name=self.name, labels=labels, annotations=annotations + ) + ), + spec=dict( + replicatedJobs=[control_job] + worker_jobs, + suspend=False, + startupPolicy=None, + successPolicy=None, + # The Failure Policy helps setting the number of retries for the jobset. + # It cannot accept a value of 0 for maxRestarts. + # So the attempt needs to be smartly set. + # If there is no retry decorator then we not set maxRestarts and instead we will + # set the attempt statically to 0. Otherwise we will make the job pickup the attempt + # from the `V1EnvVarSource.value_from.V1ObjectFieldSelector.field_path` = "metadata.annotations['jobset.sigs.k8s.io/restart-attempt']" + # failurePolicy={ + # "maxRestarts" : 1 + # }, + # The can be set for ArgoWorkflows + failurePolicy=None, + network=None, + ), + status=None, + ) + + def execute(self): + client = self._client.get() + api_instance = client.CoreV1Api() + + with client.ApiClient() as api_client: + api_instance = client.CustomObjectsApi(api_client) + try: + jobset_obj = api_instance.create_namespaced_custom_object( + group=self._group, + version=self._version, + namespace=self._namespace, + plural="jobsets", + body=self._jobset, + ) + except Exception as e: + raise KubernetesJobsetException( + "Exception when calling CustomObjectsApi->create_namespaced_custom_object: %s\n" + % e + ) + + return RunningJobSet( + client=self._client, + name=jobset_obj["metadata"]["name"], + namespace=jobset_obj["metadata"]["namespace"], + group=self._group, + version=self._version, + )