Skip to content

Commit

Permalink
[@parallel on Kubernetes] support for Jobsets
Browse files Browse the repository at this point in the history
Implementation originates from [#1744]

This commit adds support for @parallel when flows are run `--with kubernetes`
Support for Argo workflows will follow in a separate commit.

A user can run a flow with the following:

    @step
    def start(self):
        self.next(self.parallel_step, num_parallel=3)

    @kubernetes(cpu=1, memory=512)
    @parallel
    @step
    def parallel_step(self):
    ...
  • Loading branch information
shrinandj authored and valayDave committed May 10, 2024
1 parent 8b026c2 commit 1bc4c17
Show file tree
Hide file tree
Showing 8 changed files with 1,096 additions and 217 deletions.
3 changes: 3 additions & 0 deletions metaflow/metaflow_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,9 @@
ARGO_WORKFLOWS_KUBERNETES_SECRETS = from_conf("ARGO_WORKFLOWS_KUBERNETES_SECRETS", "")
ARGO_WORKFLOWS_ENV_VARS_TO_SKIP = from_conf("ARGO_WORKFLOWS_ENV_VARS_TO_SKIP", "")

KUBERNETES_JOBSET_GROUP = from_conf("KUBERNETES_JOBSET_GROUP", "jobset.x-k8s.io")
KUBERNETES_JOBSET_VERSION = from_conf("KUBERNETES_JOBSET_VERSION", "v1alpha2")

##
# Argo Events Configuration
##
Expand Down
5 changes: 5 additions & 0 deletions metaflow/plugins/argo/argo_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,6 +838,11 @@ def _dag_templates(self):
def _visit(
node, exit_node=None, templates=None, dag_tasks=None, parent_foreach=None
):
if node.parallel_foreach:
raise ArgoWorkflowsException(
"Deploying flows with @parallel decorator(s) "
"as Argo Workflows is not supported currently."
)
# Every for-each node results in a separate subDAG and an equivalent
# DAGTemplate rooted at the child of the for-each node. Each DAGTemplate
# has a unique name - the top-level DAGTemplate is named as the name of
Expand Down
81 changes: 73 additions & 8 deletions metaflow/plugins/kubernetes/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import os
import re
import shlex
import copy
import time
from typing import Dict, List, Optional
import uuid
from uuid import uuid4

from metaflow import current, util
Expand Down Expand Up @@ -66,6 +66,12 @@ class KubernetesKilledException(MetaflowException):
headline = "Kubernetes Batch job killed"


def _extract_labels_and_annotations_from_job_spec(job_spec):
annotations = job_spec.template.metadata.annotations
labels = job_spec.template.metadata.labels
return copy.copy(annotations), copy.copy(labels)


class Kubernetes(object):
def __init__(
self,
Expand Down Expand Up @@ -140,9 +146,64 @@ def _command(
return shlex.split('bash -c "%s"' % cmd_str)

def launch_job(self, **kwargs):
self._job = self.create_job(**kwargs).execute()
if (
"num_parallel" in kwargs
and kwargs["num_parallel"]
and int(kwargs["num_parallel"]) > 0
):
job = self.create_job_object(**kwargs)
spec = job.create_job_spec()
# `kwargs["step_cli"]` is setting `ubf_context` as control to ALL pods.
# This will be modified by the KubernetesJobSet object
annotations, labels = _extract_labels_and_annotations_from_job_spec(spec)
self._job = self.create_jobset(
job_spec=spec,
run_id=kwargs["run_id"],
step_name=kwargs["step_name"],
task_id=kwargs["task_id"],
namespace=kwargs["namespace"],
env=kwargs["env"],
num_parallel=kwargs["num_parallel"],
port=kwargs["port"],
annotations=annotations,
labels=labels,
).execute()
else:
kwargs["name_pattern"] = "t-{uid}-".format(uid=str(uuid4())[:8])
self._job = self.create_job_object(**kwargs).k8screate().execute()

def create_jobset(
self,
job_spec=None,
run_id=None,
step_name=None,
task_id=None,
namespace=None,
env=None,
num_parallel=None,
port=None,
annotations=None,
labels=None,
):
if env is None:
env = {}

def create_job(
_prefix = str(uuid4())[:6]
js = KubernetesClient().jobset(
name="js-%s" % _prefix,
run_id=run_id,
task_id=task_id,
step_name=step_name,
namespace=namespace,
labels=self._get_labels(labels),
annotations=annotations,
num_parallel=num_parallel,
job_spec=job_spec,
port=port,
)
return js

def create_job_object(
self,
flow_name,
run_id,
Expand Down Expand Up @@ -176,14 +237,15 @@ def create_job(
labels=None,
shared_memory=None,
port=None,
name_pattern=None,
num_parallel=None,
):
if env is None:
env = {}

job = (
KubernetesClient()
.job(
generate_name="t-{uid}-".format(uid=str(uuid4())[:8]),
generate_name=name_pattern,
namespace=namespace,
service_account=service_account,
secrets=secrets,
Expand Down Expand Up @@ -217,6 +279,7 @@ def create_job(
persistent_volume_claims=persistent_volume_claims,
shared_memory=shared_memory,
port=port,
num_parallel=num_parallel,
)
.environment_variable("METAFLOW_CODE_SHA", code_package_sha)
.environment_variable("METAFLOW_CODE_URL", code_package_url)
Expand Down Expand Up @@ -332,6 +395,9 @@ def create_job(
.label("app.kubernetes.io/part-of", "metaflow")
)

return job

def create_k8sjob(self, job):
return job.create()

def wait(self, stdout_location, stderr_location, echo=None):
Expand Down Expand Up @@ -366,7 +432,7 @@ def wait_for_launch(job):
t = time.time()
time.sleep(update_delay(time.time() - start_time))

prefix = b"[%s] " % util.to_bytes(self._job.id)
_make_prefix = lambda: b"[%s] " % util.to_bytes(self._job.id)

stdout_tail = get_log_tailer(stdout_location, self._datastore.TYPE)
stderr_tail = get_log_tailer(stderr_location, self._datastore.TYPE)
Expand All @@ -376,7 +442,7 @@ def wait_for_launch(job):

# 2) Tail logs until the job has finished
tail_logs(
prefix=prefix,
prefix=_make_prefix(),
stdout_tail=stdout_tail,
stderr_tail=stderr_tail,
echo=echo,
Expand All @@ -392,7 +458,6 @@ def wait_for_launch(job):
# exists prior to calling S3Tail and note the user about
# truncated logs if it doesn't.
# TODO : For hard crashes, we can fetch logs from the pod.

if self._job.has_failed:
exit_code, reason = self._job.reason
msg = next(
Expand Down
12 changes: 12 additions & 0 deletions metaflow/plugins/kubernetes/kubernetes_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from metaflow._vendor import click
from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY, CommandException
from metaflow.metadata.util import sync_local_metadata_from_datastore
from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK
from metaflow.metaflow_config import DATASTORE_LOCAL_DIR, KUBERNETES_LABELS
from metaflow.mflog import TASK_LOG_SOURCE
import metaflow.tracing as tracing
Expand Down Expand Up @@ -109,6 +110,15 @@ def kubernetes():
)
@click.option("--shared-memory", default=None, help="Size of shared memory in MiB")
@click.option("--port", default=None, help="Port number to expose from the container")
@click.option(
"--ubf-context", default=None, type=click.Choice([None, UBF_CONTROL, UBF_TASK])
)
@click.option(
"--num-parallel",
default=None,
type=int,
help="Number of parallel nodes to run as a multi-node job.",
)
@click.pass_context
def step(
ctx,
Expand Down Expand Up @@ -136,6 +146,7 @@ def step(
tolerations=None,
shared_memory=None,
port=None,
num_parallel=None,
**kwargs
):
def echo(msg, stream="stderr", job_id=None, **kwargs):
Expand Down Expand Up @@ -251,6 +262,7 @@ def _sync_metadata():
tolerations=tolerations,
shared_memory=shared_memory,
port=port,
num_parallel=num_parallel,
)
except Exception as e:
traceback.print_exc(chain=False)
Expand Down
5 changes: 4 additions & 1 deletion metaflow/plugins/kubernetes/kubernetes_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from metaflow.exception import MetaflowException

from .kubernetes_job import KubernetesJob
from .kubernetes_job import KubernetesJob, KubernetesJobSet


CLIENT_REFRESH_INTERVAL_SECONDS = 300
Expand Down Expand Up @@ -61,5 +61,8 @@ def get(self):

return self._client

def jobset(self, **kwargs):
return KubernetesJobSet(self, **kwargs)

def job(self, **kwargs):
return KubernetesJob(self, **kwargs)
43 changes: 37 additions & 6 deletions metaflow/plugins/kubernetes/kubernetes_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@

from ..aws.aws_utils import get_docker_registry, get_ec2_instance_metadata
from .kubernetes import KubernetesException, parse_kube_keyvalue_list
from metaflow.unbounded_foreach import UBF_CONTROL
from .kubernetes_jobsets import TaskIdConstructor

try:
unicode
Expand Down Expand Up @@ -239,12 +241,6 @@ def step_init(self, flow, graph, step, decos, environment, flow_datastore, logge
"Kubernetes. Please use one or the other.".format(step=step)
)

for deco in decos:
if getattr(deco, "IS_PARALLEL", False):
raise KubernetesException(
"@kubernetes does not support parallel execution currently."
)

# Set run time limit for the Kubernetes job.
self.run_time_limit = get_run_time_limit_for_task(decos)
if self.run_time_limit < 60:
Expand Down Expand Up @@ -453,6 +449,24 @@ def task_pre_step(
self._save_logs_sidecar = Sidecar("save_logs_periodically")
self._save_logs_sidecar.start()

num_parallel = None
if hasattr(flow, "_parallel_ubf_iter"):
num_parallel = flow._parallel_ubf_iter.num_parallel

if num_parallel and num_parallel >= 1 and ubf_context == UBF_CONTROL:
control_task_id, worker_task_ids = TaskIdConstructor.join_step_task_ids(
num_parallel
)
mapper_task_ids = [control_task_id] + worker_task_ids
flow._control_mapper_tasks = [
"%s/%s/%s" % (run_id, step_name, mapper_task_id)
for mapper_task_id in mapper_task_ids
]
flow._control_task_is_mapper_zero = True

if num_parallel and num_parallel > 1:
_setup_multinode_environment()

def task_finished(
self, step_name, flow, graph, is_task_ok, retry_count, max_retries
):
Expand Down Expand Up @@ -486,3 +500,20 @@ def _save_package_once(cls, flow_datastore, package):
cls.package_url, cls.package_sha = flow_datastore.save_data(
[package.blob], len_hint=1
)[0]


def _setup_multinode_environment():
import socket

os.environ["MF_PARALLEL_MAIN_IP"] = socket.gethostbyname(os.environ["MASTER_ADDR"])
os.environ["MF_PARALLEL_NUM_NODES"] = os.environ["WORLD_SIZE"]
if os.environ.get("CONTROL_INDEX") is not None:
os.environ["MF_PARALLEL_NODE_INDEX"] = str(0)
elif os.environ.get("WORKER_REPLICA_INDEX") is not None:
os.environ["MF_PARALLEL_NODE_INDEX"] = str(
int(os.environ["WORKER_REPLICA_INDEX"]) + 1
)
else:
raise MetaflowException(
"Jobset related ENV vars called $CONTROL_INDEX or $WORKER_REPLICA_INDEX not found"
)
Loading

0 comments on commit 1bc4c17

Please sign in to comment.