Add run policy (#2555)
Signed-off-by: bugra.gedik <bugra.gedik@predera.ai>
Co-authored-by: bugra.gedik <bugra.gedik@predera.ai>
bgedik and bugra.gedik authored Jul 3, 2024
Parent: 153b0df · Commit: b82a25f
Showing 2 changed files with 45 additions and 11 deletions.
29 changes: 19 additions & 10 deletions plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
@@ -107,7 +107,7 @@ class PyTorch(object):

     master: Master = field(default_factory=lambda: Master())
     worker: Worker = field(default_factory=lambda: Worker())
-    run_policy: Optional[RunPolicy] = field(default_factory=lambda: None)
+    run_policy: Optional[RunPolicy] = None
     # Support v0 config for backwards compatibility
     num_workers: Optional[int] = None

@@ -130,6 +130,7 @@ class Elastic(object):
         max_restarts (int): Maximum number of worker group restarts before failing.
         rdzv_configs (Dict[str, Any]): Additional rendezvous configs to pass to torch elastic, e.g. `{"timeout": 1200, "join_timeout": 900}`.
             See `torch.distributed.launcher.api.LaunchConfig` and `torch.distributed.elastic.rendezvous.dynamic_rendezvous.create_handler`.
+        run_policy: Configuration for the run policy.
     """

     nnodes: Union[int, str] = 1
@@ -138,6 +139,7 @@ class Elastic(object):
     monitor_interval: int = 5
     max_restarts: int = 0
     rdzv_configs: Dict[str, Any] = field(default_factory=dict)
+    run_policy: Optional[RunPolicy] = None


 class PyTorchFunctionTask(PythonFunctionTask[PyTorch]):
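A quick usage sketch (not part of the diff): with this change a RunPolicy can be passed directly to the Elastic task config, and the same field exists on PyTorch. The task body and policy values below are hypothetical; the names mirror the RunPolicy dataclass and the test added in this commit.

from flytekit import task
from flytekitplugins.kfpytorch.task import CleanPodPolicy, Elastic, RunPolicy

# Hypothetical values; fields left as None are simply not set on the serialized job spec.
example_policy = RunPolicy(
    clean_pod_policy=CleanPodPolicy.ALL,  # delete all pods once the job finishes
    ttl_seconds_after_finished=600,       # clean up the finished job after 10 minutes
    active_deadline_seconds=36000,        # fail the job if it runs longer than 10 hours
    backoff_limit=3,                      # retry the job at most 3 times
)

@task(task_config=Elastic(nnodes=2, nproc_per_node=2, run_policy=example_policy))
def train() -> None:  # placeholder task body
    ...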
@@ -181,21 +183,15 @@ def _convert_replica_spec(
             restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None,
         )

-    def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy:
-        return kubeflow_common.RunPolicy(
-            clean_pod_policy=run_policy.clean_pod_policy.value if run_policy.clean_pod_policy else None,
-            ttl_seconds_after_finished=run_policy.ttl_seconds_after_finished,
-            active_deadline_seconds=run_policy.active_deadline_seconds,
-            backoff_limit=run_policy.backoff_limit,
-        )
-
     def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
         worker = self._convert_replica_spec(self.task_config.worker)
         # support v0 config for backwards compatibility
         if self.task_config.num_workers:
             worker.replicas = self.task_config.num_workers

-        run_policy = self._convert_run_policy(self.task_config.run_policy) if self.task_config.run_policy else None
+        run_policy = (
+            _convert_run_policy_to_flyte_idl(self.task_config.run_policy) if self.task_config.run_policy else None
+        )
         pytorch_job = pytorch_task.DistributedPyTorchTrainingTask(
             worker_replicas=worker,
             master_replicas=self._convert_replica_spec(self.task_config.master),
@@ -263,6 +259,15 @@ def spawn_helper(
     return ElasticWorkerResult(return_value=return_val, decks=flytekit.current_context().decks)


+def _convert_run_policy_to_flyte_idl(run_policy: RunPolicy) -> kubeflow_common.RunPolicy:
+    return kubeflow_common.RunPolicy(
+        clean_pod_policy=run_policy.clean_pod_policy.value if run_policy.clean_pod_policy else None,
+        ttl_seconds_after_finished=run_policy.ttl_seconds_after_finished,
+        active_deadline_seconds=run_policy.active_deadline_seconds,
+        backoff_limit=run_policy.backoff_limit,
+    )
+
+
 class PytorchElasticFunctionTask(PythonFunctionTask[Elastic]):
     """
     Plugin for distributed training with torch elastic/torchrun (see
@@ -445,11 +450,15 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]
             nproc_per_node=self.task_config.nproc_per_node,
             max_restarts=self.task_config.max_restarts,
         )
+        run_policy = (
+            _convert_run_policy_to_flyte_idl(self.task_config.run_policy) if self.task_config.run_policy else None
+        )
         job = pytorch_task.DistributedPyTorchTrainingTask(
             worker_replicas=pytorch_task.DistributedPyTorchTrainingReplicaSpec(
                 replicas=self.max_nodes,
             ),
             elastic_config=elastic_config,
+            run_policy=run_policy,
         )
         return MessageToDict(job)

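Taken together, the task.py changes replace the PyTorchFunctionTask._convert_run_policy method with the module-level _convert_run_policy_to_flyte_idl helper, so both the PyTorch and the Elastic code paths build the kubeflow_common.RunPolicy IDL message the same way before attaching it to DistributedPyTorchTrainingTask. Below is a minimal sketch of what the helper does when called directly; it is private API, and the example values are hypothetical.

from flytekitplugins.kfpytorch.task import (
    CleanPodPolicy,
    RunPolicy,
    _convert_run_policy_to_flyte_idl,  # private helper introduced in this commit
)

idl_policy = _convert_run_policy_to_flyte_idl(
    RunPolicy(
        clean_pod_policy=CleanPodPolicy.ALL,
        ttl_seconds_after_finished=600,
        active_deadline_seconds=36000,
        backoff_limit=None,
    )
)
# idl_policy is a flyteidl kubeflow_common.RunPolicy message: clean_pod_policy carries the
# enum's .value, and fields passed as None stay unset in the message.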
27 changes: 26 additions & 1 deletion plugins/flytekit-kf-pytorch/tests/test_elastic_task.py
@@ -7,10 +7,11 @@
 import torch
 import torch.distributed as dist
 from dataclasses_json import DataClassJsonMixin
-from flytekitplugins.kfpytorch.task import Elastic
+from flytekitplugins.kfpytorch.task import CleanPodPolicy, Elastic, RunPolicy

 import flytekit
 from flytekit import task, workflow
+from flytekit.configuration import SerializationSettings
 from flytekit.exceptions.user import FlyteRecoverableException


@@ -187,3 +188,27 @@ def wf(recoverable: bool):
     else:
         with pytest.raises(RuntimeError):
             wf(recoverable=recoverable)
+
+
+def test_run_policy() -> None:
+    """Test that run policy is propagated to custom spec."""
+
+    run_policy = RunPolicy(
+        clean_pod_policy=CleanPodPolicy.ALL,
+        ttl_seconds_after_finished=10 * 60,
+        active_deadline_seconds=36000,
+        backoff_limit=None,
+    )
+
+    # nnodes must be > 1 to get pytorchjob spec
+    @task(task_config=Elastic(nnodes=2, nproc_per_node=2, run_policy=run_policy))
+    def test_task():
+        pass
+
+    spec = test_task.get_custom(SerializationSettings(image_config=None))
+
+    assert spec["runPolicy"] == {
+        "cleanPodPolicy": "CLEANPOD_POLICY_ALL",
+        "ttlSecondsAfterFinished": 600,
+        "activeDeadlineSeconds": 36000,
+    }
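A note on the assertions above: get_custom serializes the proto through google.protobuf.json_format.MessageToDict, which renders field names in camelCase, renders enums by name, and omits unset fields; that is why the expected dict uses cleanPodPolicy with CLEANPOD_POLICY_ALL and why backoff_limit=None never appears. A small sketch for inspecting the same spec outside pytest; the task definition is hypothetical and reuses the test's policy values.

import json

from flytekit import task
from flytekit.configuration import SerializationSettings
from flytekitplugins.kfpytorch.task import CleanPodPolicy, Elastic, RunPolicy

# Same policy values as the test above.
policy = RunPolicy(
    clean_pod_policy=CleanPodPolicy.ALL,
    ttl_seconds_after_finished=600,
    active_deadline_seconds=36000,
    backoff_limit=None,
)

@task(task_config=Elastic(nnodes=2, nproc_per_node=2, run_policy=policy))
def train() -> None:  # placeholder task body
    ...

spec = train.get_custom(SerializationSettings(image_config=None))
print(json.dumps(spec["runPolicy"], indent=2))
# Prints cleanPodPolicy, ttlSecondsAfterFinished, and activeDeadlineSeconds; backoffLimit is omitted.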
