diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/pod_template.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/pod_template.py new file mode 100644 index 0000000000..8d8567d3e7 --- /dev/null +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/pod_template.py @@ -0,0 +1,47 @@ +from kubernetes.client import V1Container, V1EmptyDirVolumeSource, V1PodSpec, V1Volume, V1VolumeMount + +from flytekit.core.pod_template import PodTemplate + + +def add_shared_mem_volume_to_pod_template(pod_template: PodTemplate) -> None: + """Add shared memory volume and volume mount to the pod template.""" + mount_path = "/dev/shm" + shm_volume = V1Volume(name="shm", empty_dir=V1EmptyDirVolumeSource(medium="Memory")) + shm_volume_mount = V1VolumeMount(name="shm", mount_path=mount_path) + + if pod_template.pod_spec is None: + pod_template.pod_spec = V1PodSpec() + + if pod_template.pod_spec.containers is None: + pod_template.pod_spec.containers = [] + + if pod_template.pod_spec.volumes is None: + pod_template.pod_spec.volumes = [] + + pod_template.pod_spec.volumes.append(shm_volume) + + num_containers = len(pod_template.pod_spec.containers) + + if num_containers >= 2: + raise ValueError( + "When configuring a pod template with multiple containers, please set `increase_shared_mem=False` " + "in the task config and if required mount a volume to increase the shared memory size in the respective " + "container yourself." + ) + + if num_containers != 1: + pod_template.pod_spec.containers.append(V1Container(name="primary")) + + if pod_template.pod_spec.containers[0].volume_mounts is None: + pod_template.pod_spec.containers[0].volume_mounts = [] + + has_shared_mem_vol_mount = any( + [v.mount_path == mount_path for v in pod_template.pod_spec.containers[0].volume_mounts] + ) + if has_shared_mem_vol_mount: + raise ValueError( + "A shared memory volume mount is already configured in the pod template. " + "Please remove the volume mount or set `increase_shared_mem=False` in the task config." + ) + + pod_template.pod_spec.containers[0].volume_mounts.append(shm_volume_mount) diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py index 73d5e07c1b..0fab224fa2 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py @@ -15,12 +15,14 @@ import flytekit from flytekit import PythonFunctionTask, Resources, lazy_module from flytekit.configuration import SerializationSettings +from flytekit.core.pod_template import PodTemplate from flytekit.core.resources import convert_resources_to_resource_model from flytekit.exceptions.user import FlyteRecoverableException from flytekit.extend import IgnoreOutputs, TaskPlugins from flytekit.loggers import logger from .error_handling import create_recoverable_error_file, is_recoverable_worker_error +from .pod_template import add_shared_mem_volume_to_pod_template cloudpickle = lazy_module("cloudpickle") @@ -103,6 +105,11 @@ class PyTorch(object): worker: Configuration for the worker replica group. run_policy: Configuration for the run policy. num_workers: [DEPRECATED] This argument is deprecated. Use `worker.replicas` instead. + increase_shared_mem (bool): PyTorch uses shared memory to share data between processes. If torch multiprocessing is used + (e.g. for multi-processed data loaders) the default shared memory segment size that the container runs with might not be enough + and and one might have to increase the shared memory size. This option configures the task's pod template to mount + an `emptyDir` volume with medium `Memory` to to `/dev/shm`. + The shared memory size upper limit is the sum of the memory limits of the containers in the pod. """ master: Master = field(default_factory=lambda: Master()) @@ -110,6 +117,7 @@ class PyTorch(object): run_policy: Optional[RunPolicy] = None # Support v0 config for backwards compatibility num_workers: Optional[int] = None + increase_shared_mem: bool = True @dataclass @@ -134,6 +142,14 @@ class Elastic(object): max_restarts (int): Maximum number of worker group restarts before failing. rdzv_configs (Dict[str, Any]): Additional rendezvous configs to pass to torch elastic, e.g. `{"timeout": 1200, "join_timeout": 900}`. See `torch.distributed.launcher.api.LaunchConfig` and `torch.distributed.elastic.rendezvous.dynamic_rendezvous.create_handler`. + Default timeouts are set to 15 minutes to account for the fact that some workers might start faster than others: Some pods might + be assigned to a running node which might have the image in its cache while other workers might require a node scale up and image pull. + + increase_shared_mem (bool): PyTorch uses shared memory to share data between processes. If torch multiprocessing is used + (e.g. for multi-processed data loaders) the default shared memory segment size that the container runs with might not be enough + and and one might have to increase the shared memory size. This option configures the task's pod template to mount + an `emptyDir` volume with medium `Memory` to to `/dev/shm`. + The shared memory size upper limit is the sum of the memory limits of the containers in the pod. run_policy: Configuration for the run policy. """ @@ -142,7 +158,8 @@ class Elastic(object): start_method: str = "spawn" monitor_interval: int = 5 max_restarts: int = 0 - rdzv_configs: Dict[str, Any] = field(default_factory=dict) + rdzv_configs: Dict[str, Any] = field(default_factory=lambda: {"timeout": 900, "join_timeout": 900}) + increase_shared_mem: bool = True run_policy: Optional[RunPolicy] = None @@ -171,6 +188,10 @@ def __init__(self, task_config: PyTorch, task_function: Callable, **kwargs): task_type_version=1, **kwargs, ) + if self.task_config.increase_shared_mem: + if self.pod_template is None: + self.pod_template = PodTemplate() + add_shared_mem_volume_to_pod_template(self.pod_template) def _convert_replica_spec( self, replica_config: Union[Master, Worker] @@ -308,6 +329,11 @@ def __init__(self, task_config: Elastic, task_function: Callable, **kwargs): """ self.rdzv_backend = "c10d" + if self.task_config.increase_shared_mem: + if self.pod_template is None: + self.pod_template = PodTemplate() + add_shared_mem_volume_to_pod_template(self.pod_template) + def _execute(self, **kwargs) -> Any: """ Execute the task function using torch distributed's `elastic_launch`. diff --git a/plugins/flytekit-kf-pytorch/setup.py b/plugins/flytekit-kf-pytorch/setup.py index cc90e0b299..317ca7b8a0 100644 --- a/plugins/flytekit-kf-pytorch/setup.py +++ b/plugins/flytekit-kf-pytorch/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["cloudpickle", "flyteidl>=1.5.1", "flytekit>=1.6.1"] +plugin_requires = ["cloudpickle", "flyteidl>=1.5.1", "flytekit>=1.6.1", "kubernetes"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py index 9cb62b993c..b56fc0aa08 100644 --- a/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py +++ b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py @@ -196,6 +196,15 @@ def wf(recoverable: bool): wf(recoverable=recoverable) +def test_default_timeouts(): + """Test that default timeouts are set for the elastic task.""" + @task(task_config=Elastic(nnodes=1)) + def test_task(): + pass + + assert test_task.task_config.rdzv_configs == {"join_timeout": 900, "timeout": 900} + + def test_run_policy() -> None: """Test that run policy is propagated to custom spec.""" diff --git a/plugins/flytekit-kf-pytorch/tests/test_shared.py b/plugins/flytekit-kf-pytorch/tests/test_shared.py new file mode 100644 index 0000000000..b86f9a73d9 --- /dev/null +++ b/plugins/flytekit-kf-pytorch/tests/test_shared.py @@ -0,0 +1,138 @@ +"""Test functionality that is shared between the pytorch and pytorch-elastic tasks.""" + +from contextlib import nullcontext +from typing import Union + +import pytest +from flytekitplugins.kfpytorch.task import Elastic, PyTorch +from kubernetes.client import V1Container, V1EmptyDirVolumeSource, V1PodSpec, V1Volume, V1VolumeMount + +from flytekit import PodTemplate, task + + +@pytest.mark.parametrize( + "task_config, pod_template, needs_shm_volume, raises", + [ + # Test that by default shared memory volume is added + (PyTorch(num_workers=3), None, True, False), + (Elastic(nnodes=2, increase_shared_mem=True), None, True, False), + # Test disabling shared memory volume + (PyTorch(num_workers=3, increase_shared_mem=False), None, False, False), + (Elastic(nnodes=2, increase_shared_mem=False), None, False, False), + # Test that explicitly passed pod template does not break adding shm volume + (Elastic(nnodes=2, increase_shared_mem=True), PodTemplate(), True, False), + # Test that pod template with container does not break adding shm volume + ( + Elastic(nnodes=2), + PodTemplate( + pod_spec=V1PodSpec(containers=[V1Container(name="primary")]), + ), + True, + False, + ), + # Test that pod template with volume/volume mount does not break adding shm volume + ( + Elastic(nnodes=2), + PodTemplate( + pod_spec=V1PodSpec( + containers=[ + V1Container(name="primary", volume_mounts=[V1VolumeMount(name="foo", mount_path="/bar")]) + ], + volumes=[V1Volume(name="foo")], + ), + ), + True, + False, + ), + # Test that pod template with multiple containers raises an error + ( + Elastic(nnodes=2), + PodTemplate( + pod_spec=V1PodSpec( + containers=[ + V1Container(name="primary"), + V1Container(name="secondary"), + ] + ), + ), + True, + True, + ), + # Test that explicitly configured pod template with shared memory volume is not removed if `increase_shared_mem=False` + ( + Elastic(nnodes=2, increase_shared_mem=False), + PodTemplate( + pod_spec=V1PodSpec( + containers=[ + V1Container(name="primary", volume_mounts=[V1VolumeMount(name="shm", mount_path="/dev/shm")]), + ], + volumes=[V1Volume(name="shm", empty_dir=V1EmptyDirVolumeSource(medium="Memory"))], + ), + ), + True, + False, + ), + # Test that we raise if the user explicitly configured a shared memory volume and still configures the task config to add it + ( + Elastic(nnodes=2, increase_shared_mem=True), + PodTemplate( + pod_spec=V1PodSpec( + containers=[ + V1Container(name="primary", volume_mounts=[V1VolumeMount(name="shm", mount_path="/dev/shm")]), + ], + volumes=[V1Volume(name="shm", empty_dir=V1EmptyDirVolumeSource(medium="Memory"))], + ), + ), + True, + True, + ), + ], +) +def test_task_shared_memory( + task_config: Union[Elastic, PyTorch], pod_template: PodTemplate, needs_shm_volume: bool, raises: bool +): + """Test that the task pod template is configured with a shared memory volume if needed.""" + + expected_volume = V1Volume(name="shm", empty_dir=V1EmptyDirVolumeSource(medium="Memory")) + expected_volume_mount = V1VolumeMount(name="shm", mount_path="/dev/shm") + + with pytest.raises(ValueError) if raises else nullcontext(): + + @task( + task_config=task_config, + pod_template=pod_template, + ) + def test_task() -> None: + pass + + if needs_shm_volume: + assert test_task.pod_template is not None + assert test_task.pod_template.pod_spec is not None + assert test_task.pod_template.pod_spec.volumes is not None + assert test_task.pod_template.pod_spec.containers is not None + assert test_task.pod_template.pod_spec.containers[0].volume_mounts is not None + + assert any([v == expected_volume for v in test_task.pod_template.pod_spec.volumes]) + assert any( + [v == expected_volume_mount for v in test_task.pod_template.pod_spec.containers[0].volume_mounts] + ) + + else: + # Check that the shared memory volume + volume mount is not added + no_pod_template = test_task.pod_template is None + no_pod_spec = no_pod_template or test_task.pod_template.pod_spec is None + no_volumes = no_pod_spec or test_task.pod_template.pod_spec.volumes is None + no_containers = no_pod_spec or len(test_task.pod_template.pod_spec.containers) == 0 + no_volume_mounts = no_containers or test_task.pod_template.pod_spec.containers[0].volume_mounts is None + empty_volume_mounts = ( + no_volume_mounts or len(test_task.pod_template.pod_spec.containers[0].volume_mounts) == 0 + ) + no_shm_volume_condition = no_volumes or not any( + [v == expected_volume for v in test_task.pod_template.pod_spec.volumes] + ) + no_shm_volume_mount_condition = empty_volume_mounts or not any( + [v == expected_volume_mount for v in test_task.pod_template.pod_spec.containers[0].volume_mounts] + ) + + assert no_shm_volume_condition + assert no_shm_volume_mount_condition