Skip to content

Commit

Permalink
Feat: Improve UX of pytorch-elastic plugin by configuring reasonable …
Browse files Browse the repository at this point in the history
…defaults (#2543)

* Add flag to Elastic and PyTorch task config which configures shared memory volume

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Set reasonable default timeouts for pytorch elastic task config

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Lint

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Lint

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Add shm size upper limit to docstring

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Correct docstring: multi-threading -> multi-processing

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Add kubernetes dep

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Refactor the check of num containers as proposed in code review

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Lint

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Test that explicitly configured pod template with shm vol is not removed if disable in task config

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

* Raise if the user explicitly configured shm vol mount and still sets task config to add it

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>

---------

Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com>
Co-authored-by: Fabio Grätz <fabiogratz@googlemail.com>
  • Loading branch information
fg91 and Fabio Grätz authored Jul 21, 2024
1 parent 551924a commit efed441
Show file tree
Hide file tree
Showing 5 changed files with 222 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from kubernetes.client import V1Container, V1EmptyDirVolumeSource, V1PodSpec, V1Volume, V1VolumeMount

from flytekit.core.pod_template import PodTemplate


def add_shared_mem_volume_to_pod_template(pod_template: PodTemplate) -> None:
"""Add shared memory volume and volume mount to the pod template."""
mount_path = "/dev/shm"
shm_volume = V1Volume(name="shm", empty_dir=V1EmptyDirVolumeSource(medium="Memory"))
shm_volume_mount = V1VolumeMount(name="shm", mount_path=mount_path)

if pod_template.pod_spec is None:
pod_template.pod_spec = V1PodSpec()

if pod_template.pod_spec.containers is None:
pod_template.pod_spec.containers = []

if pod_template.pod_spec.volumes is None:
pod_template.pod_spec.volumes = []

pod_template.pod_spec.volumes.append(shm_volume)

num_containers = len(pod_template.pod_spec.containers)

if num_containers >= 2:
raise ValueError(
"When configuring a pod template with multiple containers, please set `increase_shared_mem=False` "
"in the task config and if required mount a volume to increase the shared memory size in the respective "
"container yourself."
)

if num_containers != 1:
pod_template.pod_spec.containers.append(V1Container(name="primary"))

if pod_template.pod_spec.containers[0].volume_mounts is None:
pod_template.pod_spec.containers[0].volume_mounts = []

has_shared_mem_vol_mount = any(
[v.mount_path == mount_path for v in pod_template.pod_spec.containers[0].volume_mounts]
)
if has_shared_mem_vol_mount:
raise ValueError(
"A shared memory volume mount is already configured in the pod template. "
"Please remove the volume mount or set `increase_shared_mem=False` in the task config."
)

pod_template.pod_spec.containers[0].volume_mounts.append(shm_volume_mount)
28 changes: 27 additions & 1 deletion plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
import flytekit
from flytekit import PythonFunctionTask, Resources, lazy_module
from flytekit.configuration import SerializationSettings
from flytekit.core.pod_template import PodTemplate
from flytekit.core.resources import convert_resources_to_resource_model
from flytekit.exceptions.user import FlyteRecoverableException
from flytekit.extend import IgnoreOutputs, TaskPlugins
from flytekit.loggers import logger

from .error_handling import create_recoverable_error_file, is_recoverable_worker_error
from .pod_template import add_shared_mem_volume_to_pod_template

cloudpickle = lazy_module("cloudpickle")

Expand Down Expand Up @@ -103,13 +105,19 @@ class PyTorch(object):
worker: Configuration for the worker replica group.
run_policy: Configuration for the run policy.
num_workers: [DEPRECATED] This argument is deprecated. Use `worker.replicas` instead.
increase_shared_mem (bool): PyTorch uses shared memory to share data between processes. If torch multiprocessing is used
(e.g. for multi-processed data loaders) the default shared memory segment size that the container runs with might not be enough
and and one might have to increase the shared memory size. This option configures the task's pod template to mount
an `emptyDir` volume with medium `Memory` to to `/dev/shm`.
The shared memory size upper limit is the sum of the memory limits of the containers in the pod.
"""

master: Master = field(default_factory=lambda: Master())
worker: Worker = field(default_factory=lambda: Worker())
run_policy: Optional[RunPolicy] = None
# Support v0 config for backwards compatibility
num_workers: Optional[int] = None
increase_shared_mem: bool = True


@dataclass
Expand All @@ -134,6 +142,14 @@ class Elastic(object):
max_restarts (int): Maximum number of worker group restarts before failing.
rdzv_configs (Dict[str, Any]): Additional rendezvous configs to pass to torch elastic, e.g. `{"timeout": 1200, "join_timeout": 900}`.
See `torch.distributed.launcher.api.LaunchConfig` and `torch.distributed.elastic.rendezvous.dynamic_rendezvous.create_handler`.
Default timeouts are set to 15 minutes to account for the fact that some workers might start faster than others: Some pods might
be assigned to a running node which might have the image in its cache while other workers might require a node scale up and image pull.
increase_shared_mem (bool): PyTorch uses shared memory to share data between processes. If torch multiprocessing is used
(e.g. for multi-processed data loaders) the default shared memory segment size that the container runs with might not be enough
and and one might have to increase the shared memory size. This option configures the task's pod template to mount
an `emptyDir` volume with medium `Memory` to to `/dev/shm`.
The shared memory size upper limit is the sum of the memory limits of the containers in the pod.
run_policy: Configuration for the run policy.
"""

Expand All @@ -142,7 +158,8 @@ class Elastic(object):
start_method: str = "spawn"
monitor_interval: int = 5
max_restarts: int = 0
rdzv_configs: Dict[str, Any] = field(default_factory=dict)
rdzv_configs: Dict[str, Any] = field(default_factory=lambda: {"timeout": 900, "join_timeout": 900})
increase_shared_mem: bool = True
run_policy: Optional[RunPolicy] = None


Expand Down Expand Up @@ -171,6 +188,10 @@ def __init__(self, task_config: PyTorch, task_function: Callable, **kwargs):
task_type_version=1,
**kwargs,
)
if self.task_config.increase_shared_mem:
if self.pod_template is None:
self.pod_template = PodTemplate()
add_shared_mem_volume_to_pod_template(self.pod_template)

def _convert_replica_spec(
self, replica_config: Union[Master, Worker]
Expand Down Expand Up @@ -308,6 +329,11 @@ def __init__(self, task_config: Elastic, task_function: Callable, **kwargs):
"""
self.rdzv_backend = "c10d"

if self.task_config.increase_shared_mem:
if self.pod_template is None:
self.pod_template = PodTemplate()
add_shared_mem_volume_to_pod_template(self.pod_template)

def _execute(self, **kwargs) -> Any:
"""
Execute the task function using torch distributed's `elastic_launch`.
Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-kf-pytorch/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = ["cloudpickle", "flyteidl>=1.5.1", "flytekit>=1.6.1"]
plugin_requires = ["cloudpickle", "flyteidl>=1.5.1", "flytekit>=1.6.1", "kubernetes"]

__version__ = "0.0.0+develop"

Expand Down
9 changes: 9 additions & 0 deletions plugins/flytekit-kf-pytorch/tests/test_elastic_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,15 @@ def wf(recoverable: bool):
wf(recoverable=recoverable)


def test_default_timeouts():
"""Test that default timeouts are set for the elastic task."""
@task(task_config=Elastic(nnodes=1))
def test_task():
pass

assert test_task.task_config.rdzv_configs == {"join_timeout": 900, "timeout": 900}


def test_run_policy() -> None:
"""Test that run policy is propagated to custom spec."""

Expand Down
138 changes: 138 additions & 0 deletions plugins/flytekit-kf-pytorch/tests/test_shared.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""Test functionality that is shared between the pytorch and pytorch-elastic tasks."""

from contextlib import nullcontext
from typing import Union

import pytest
from flytekitplugins.kfpytorch.task import Elastic, PyTorch
from kubernetes.client import V1Container, V1EmptyDirVolumeSource, V1PodSpec, V1Volume, V1VolumeMount

from flytekit import PodTemplate, task


@pytest.mark.parametrize(
"task_config, pod_template, needs_shm_volume, raises",
[
# Test that by default shared memory volume is added
(PyTorch(num_workers=3), None, True, False),
(Elastic(nnodes=2, increase_shared_mem=True), None, True, False),
# Test disabling shared memory volume
(PyTorch(num_workers=3, increase_shared_mem=False), None, False, False),
(Elastic(nnodes=2, increase_shared_mem=False), None, False, False),
# Test that explicitly passed pod template does not break adding shm volume
(Elastic(nnodes=2, increase_shared_mem=True), PodTemplate(), True, False),
# Test that pod template with container does not break adding shm volume
(
Elastic(nnodes=2),
PodTemplate(
pod_spec=V1PodSpec(containers=[V1Container(name="primary")]),
),
True,
False,
),
# Test that pod template with volume/volume mount does not break adding shm volume
(
Elastic(nnodes=2),
PodTemplate(
pod_spec=V1PodSpec(
containers=[
V1Container(name="primary", volume_mounts=[V1VolumeMount(name="foo", mount_path="/bar")])
],
volumes=[V1Volume(name="foo")],
),
),
True,
False,
),
# Test that pod template with multiple containers raises an error
(
Elastic(nnodes=2),
PodTemplate(
pod_spec=V1PodSpec(
containers=[
V1Container(name="primary"),
V1Container(name="secondary"),
]
),
),
True,
True,
),
# Test that explicitly configured pod template with shared memory volume is not removed if `increase_shared_mem=False`
(
Elastic(nnodes=2, increase_shared_mem=False),
PodTemplate(
pod_spec=V1PodSpec(
containers=[
V1Container(name="primary", volume_mounts=[V1VolumeMount(name="shm", mount_path="/dev/shm")]),
],
volumes=[V1Volume(name="shm", empty_dir=V1EmptyDirVolumeSource(medium="Memory"))],
),
),
True,
False,
),
# Test that we raise if the user explicitly configured a shared memory volume and still configures the task config to add it
(
Elastic(nnodes=2, increase_shared_mem=True),
PodTemplate(
pod_spec=V1PodSpec(
containers=[
V1Container(name="primary", volume_mounts=[V1VolumeMount(name="shm", mount_path="/dev/shm")]),
],
volumes=[V1Volume(name="shm", empty_dir=V1EmptyDirVolumeSource(medium="Memory"))],
),
),
True,
True,
),
],
)
def test_task_shared_memory(
task_config: Union[Elastic, PyTorch], pod_template: PodTemplate, needs_shm_volume: bool, raises: bool
):
"""Test that the task pod template is configured with a shared memory volume if needed."""

expected_volume = V1Volume(name="shm", empty_dir=V1EmptyDirVolumeSource(medium="Memory"))
expected_volume_mount = V1VolumeMount(name="shm", mount_path="/dev/shm")

with pytest.raises(ValueError) if raises else nullcontext():

@task(
task_config=task_config,
pod_template=pod_template,
)
def test_task() -> None:
pass

if needs_shm_volume:
assert test_task.pod_template is not None
assert test_task.pod_template.pod_spec is not None
assert test_task.pod_template.pod_spec.volumes is not None
assert test_task.pod_template.pod_spec.containers is not None
assert test_task.pod_template.pod_spec.containers[0].volume_mounts is not None

assert any([v == expected_volume for v in test_task.pod_template.pod_spec.volumes])
assert any(
[v == expected_volume_mount for v in test_task.pod_template.pod_spec.containers[0].volume_mounts]
)

else:
# Check that the shared memory volume + volume mount is not added
no_pod_template = test_task.pod_template is None
no_pod_spec = no_pod_template or test_task.pod_template.pod_spec is None
no_volumes = no_pod_spec or test_task.pod_template.pod_spec.volumes is None
no_containers = no_pod_spec or len(test_task.pod_template.pod_spec.containers) == 0
no_volume_mounts = no_containers or test_task.pod_template.pod_spec.containers[0].volume_mounts is None
empty_volume_mounts = (
no_volume_mounts or len(test_task.pod_template.pod_spec.containers[0].volume_mounts) == 0
)
no_shm_volume_condition = no_volumes or not any(
[v == expected_volume for v in test_task.pod_template.pod_spec.volumes]
)
no_shm_volume_mount_condition = empty_volume_mounts or not any(
[v == expected_volume_mount for v in test_task.pod_template.pod_spec.containers[0].volume_mounts]
)

assert no_shm_volume_condition
assert no_shm_volume_mount_condition

0 comments on commit efed441

Please sign in to comment.