From 79970bff9973e2e29274d2d776c4a5e7530194cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Mon, 1 Jul 2024 22:10:27 +0200 Subject: [PATCH 01/11] Add flag to Elastic and PyTorch task config which configures shared memory volume MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Fabio Grätz --- .../flytekitplugins/kfpytorch/pod_template.py | 36 ++++++ .../flytekitplugins/kfpytorch/task.py | 22 ++++ .../flytekit-kf-pytorch/tests/test_shared.py | 109 ++++++++++++++++++ 3 files changed, 167 insertions(+) create mode 100644 plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/pod_template.py create mode 100644 plugins/flytekit-kf-pytorch/tests/test_shared.py diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/pod_template.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/pod_template.py new file mode 100644 index 0000000000..46238e1aa9 --- /dev/null +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/pod_template.py @@ -0,0 +1,36 @@ +from flytekit.core.pod_template import PodTemplate + +from kubernetes.client import V1Container, V1PodSpec, V1Volume, V1VolumeMount, V1EmptyDirVolumeSource + + +def add_shared_mem_volume_to_pod_template(pod_template: PodTemplate) -> None: + """Add shared memory volume and volume mount to the pod template.""" + shm_volume = V1Volume(name="shm", empty_dir=V1EmptyDirVolumeSource(medium="Memory")) + shm_volume_mount = V1VolumeMount(name="shm", mount_path="/dev/shm") + + if pod_template.pod_spec is None: + pod_template.pod_spec = V1PodSpec() + + if pod_template.pod_spec.containers is None: + pod_template.pod_spec.containers = [] + + if pod_template.pod_spec.volumes is None: + pod_template.pod_spec.volumes = [] + pod_template.pod_spec.volumes.append(shm_volume) + + num_containers = len(pod_template.pod_spec.containers) + if num_containers == 0: + pod_template.pod_spec.containers.append(V1Container(name="primary")) + elif num_containers == 
1: + pass + else: + raise ValueError( + "When configuring a pod template with multiple containers, please set `increase_shared_mem=False` " + "in the task config and if required mount a volume to increase the shared memory size in the respective " + "container yourself." + ) + + if pod_template.pod_spec.containers[0].volume_mounts is None: + pod_template.pod_spec.containers[0].volume_mounts = [] + + pod_template.pod_spec.containers[0].volume_mounts.append(shm_volume_mount) diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py index 94b575e2a9..3d5ec5a092 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py @@ -15,12 +15,15 @@ import flytekit from flytekit import PythonFunctionTask, Resources, lazy_module from flytekit.configuration import SerializationSettings +from flytekit.core.pod_template import PodTemplate from flytekit.core.resources import convert_resources_to_resource_model from flytekit.exceptions.user import FlyteRecoverableException from flytekit.extend import IgnoreOutputs, TaskPlugins from flytekit.loggers import logger from .error_handling import create_recoverable_error_file, is_recoverable_worker_error +from .pod_template import add_shared_mem_volume_to_pod_template + cloudpickle = lazy_module("cloudpickle") @@ -103,6 +106,10 @@ class PyTorch(object): worker: Configuration for the worker replica group. run_policy: Configuration for the run policy. num_workers: [DEPRECATED] This argument is deprecated. Use `worker.replicas` instead. + increase_shared_mem (bool): PyTorch uses shared memory to share data between processes. If torch multiprocessing is used + (e.g. for multithreaded data loaders) the default shared memory segment size that the container runs with might not be enough + and and one might have to increase the shared memory size. 
This option configures the task's pod template to mount + an `emptyDir` volume with medium `Memory` to to `/dev/shm`. """ master: Master = field(default_factory=lambda: Master()) @@ -110,6 +117,7 @@ class PyTorch(object): run_policy: Optional[RunPolicy] = field(default_factory=lambda: None) # Support v0 config for backwards compatibility num_workers: Optional[int] = None + increase_shared_mem: bool = True @dataclass @@ -130,6 +138,10 @@ class Elastic(object): max_restarts (int): Maximum number of worker group restarts before failing. rdzv_configs (Dict[str, Any]): Additional rendezvous configs to pass to torch elastic, e.g. `{"timeout": 1200, "join_timeout": 900}`. See `torch.distributed.launcher.api.LaunchConfig` and `torch.distributed.elastic.rendezvous.dynamic_rendezvous.create_handler`. + increase_shared_mem (bool): PyTorch uses shared memory to share data between processes. If torch multiprocessing is used + (e.g. for multithreaded data loaders) the default shared memory segment size that the container runs with might not be enough + and and one might have to increase the shared memory size. This option configures the task's pod template to mount + an `emptyDir` volume with medium `Memory` to to `/dev/shm`. 
""" nnodes: Union[int, str] = 1 @@ -138,6 +150,7 @@ class Elastic(object): monitor_interval: int = 5 max_restarts: int = 0 rdzv_configs: Dict[str, Any] = field(default_factory=dict) + increase_shared_mem: bool = True class PyTorchFunctionTask(PythonFunctionTask[PyTorch]): @@ -165,6 +178,10 @@ def __init__(self, task_config: PyTorch, task_function: Callable, **kwargs): task_type_version=1, **kwargs, ) + if self.task_config.increase_shared_mem: + if self.pod_template is None: + self.pod_template = PodTemplate() + add_shared_mem_volume_to_pod_template(self.pod_template) def _convert_replica_spec( self, replica_config: Union[Master, Worker] @@ -299,6 +316,11 @@ def __init__(self, task_config: Elastic, task_function: Callable, **kwargs): """ self.rdzv_backend = "c10d" + if self.task_config.increase_shared_mem: + if self.pod_template is None: + self.pod_template = PodTemplate() + add_shared_mem_volume_to_pod_template(self.pod_template) + def _execute(self, **kwargs) -> Any: """ Execute the task function using torch distributed's `elastic_launch`. 
diff --git a/plugins/flytekit-kf-pytorch/tests/test_shared.py b/plugins/flytekit-kf-pytorch/tests/test_shared.py new file mode 100644 index 0000000000..544d50135b --- /dev/null +++ b/plugins/flytekit-kf-pytorch/tests/test_shared.py @@ -0,0 +1,109 @@ +"""Test functionality that is shared between the pytorch and pytorch-elastic tasks.""" + +from contextlib import nullcontext +from typing import Union + +import pytest +from flytekitplugins.kfpytorch.task import Elastic, PyTorch +from kubernetes.client import V1Container, V1EmptyDirVolumeSource, V1PodSpec, V1Volume, V1VolumeMount + +from flytekit import PodTemplate, task + + +@pytest.mark.parametrize( + "task_config, pod_template, needs_shm_volume, raises", + [ + # Test that by default shared memory volume is added + (PyTorch(num_workers=3), None, True, False), + (Elastic(nnodes=2, increase_shared_mem=True), None, True, False), + # Test disabling shared memory volume + (PyTorch(num_workers=3, increase_shared_mem=False), None, False, False), + (Elastic(nnodes=2, increase_shared_mem=False), None, False, False), + # Test that explictly passed pod template does not break adding shm volume + (Elastic(nnodes=2, increase_shared_mem=True), PodTemplate(), True, False), + # Test that pod template with container does not break adding shm volume + ( + Elastic(nnodes=2), + PodTemplate( + pod_spec=V1PodSpec(containers=[V1Container(name="primary")]), + ), + True, + False, + ), + # Test that pod template with volume/volume mount does not break adding shm volume + ( + Elastic(nnodes=2), + PodTemplate( + pod_spec=V1PodSpec( + containers=[ + V1Container(name="primary", volume_mounts=[V1VolumeMount(name="foo", mount_path="/bar")]) + ], + volumes=[V1Volume(name="foo")], + ), + ), + True, + False, + ), + # Test that pod template with multiple containers raises an error + ( + Elastic(nnodes=2), + PodTemplate( + pod_spec=V1PodSpec( + containers=[ + V1Container(name="primary"), + V1Container(name="secondary"), + ] + ), + ), + True, + True, + 
), + ], +) +def test_task_shared_memory( + task_config: Union[Elastic, PyTorch], pod_template: PodTemplate, needs_shm_volume: bool, raises: bool +): + """Test that the task pod template is configured with a shared memory volume if needed.""" + + expected_volume = V1Volume(name="shm", empty_dir=V1EmptyDirVolumeSource(medium="Memory")) + expected_volume_mount = V1VolumeMount(name="shm", mount_path="/dev/shm") + + with pytest.raises(ValueError) if raises else nullcontext(): + + @task( + task_config=task_config, + pod_template=pod_template, + ) + def test_task() -> None: + pass + + if needs_shm_volume: + assert test_task.pod_template is not None + assert test_task.pod_template.pod_spec is not None + assert test_task.pod_template.pod_spec.volumes is not None + assert test_task.pod_template.pod_spec.containers is not None + assert test_task.pod_template.pod_spec.containers[0].volume_mounts is not None + + assert any([v == expected_volume for v in test_task.pod_template.pod_spec.volumes]) + assert any( + [v == expected_volume_mount for v in test_task.pod_template.pod_spec.containers[0].volume_mounts] + ) + + else: + no_pod_template = test_task.pod_template is None + no_pod_spec = no_pod_template or test_task.pod_template.pod_spec is None + no_volumes = no_pod_spec or test_task.pod_template.pod_spec.volumes is None + no_containers = no_pod_spec or len(test_task.pod_template.pod_spec.containers) == 0 + no_volume_mounts = no_containers or test_task.pod_template.pod_spec.containers[0].volume_mounts is None + empty_volume_mounts = ( + no_volume_mounts or len(test_task.pod_template.pod_spec.containers[0].volume_mounts) == 0 + ) + no_shm_volume_condition = no_volumes or not any( + [v == expected_volume for v in test_task.pod_template.pod_spec.volumes] + ) + no_shm_volume_mount_condition = empty_volume_mounts or not any( + [v == expected_volume_mount for v in test_task.pod_template.pod_spec.containers[0].volume_mounts] + ) + + assert no_shm_volume_condition + assert 
no_shm_volume_mount_condition From 1009f76272b1d8923b5e094153faea13d49746fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Mon, 1 Jul 2024 22:18:59 +0200 Subject: [PATCH 02/11] Set reasonable default timeouts for pytorch elastic task config MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Fabio Grätz --- .../flytekitplugins/kfpytorch/task.py | 6 +++++- plugins/flytekit-kf-pytorch/tests/test_elastic_task.py | 9 +++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py index 3d5ec5a092..bd9b957fca 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py @@ -138,6 +138,10 @@ class Elastic(object): max_restarts (int): Maximum number of worker group restarts before failing. rdzv_configs (Dict[str, Any]): Additional rendezvous configs to pass to torch elastic, e.g. `{"timeout": 1200, "join_timeout": 900}`. See `torch.distributed.launcher.api.LaunchConfig` and `torch.distributed.elastic.rendezvous.dynamic_rendezvous.create_handler`. + + Default timeouts are set to 15 minutes to account for the fact that some workers might start faster than others: Some pods might + be assigned to a running node which might have the image in its cache while other workers might require a node scale up and image pull. + increase_shared_mem (bool): PyTorch uses shared memory to share data between processes. If torch multiprocessing is used (e.g. for multithreaded data loaders) the default shared memory segment size that the container runs with might not be enough and and one might have to increase the shared memory size. 
This option configures the task's pod template to mount @@ -149,7 +153,7 @@ class Elastic(object): start_method: str = "spawn" monitor_interval: int = 5 max_restarts: int = 0 - rdzv_configs: Dict[str, Any] = field(default_factory=dict) + rdzv_configs: Dict[str, Any] = field(default_factory=lambda: {"timeout": 900, "join_timeout": 900}) increase_shared_mem: bool = True diff --git a/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py index fd13a39659..01bf143cef 100644 --- a/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py +++ b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py @@ -187,3 +187,12 @@ def wf(recoverable: bool): else: with pytest.raises(RuntimeError): wf(recoverable=recoverable) + + +def test_default_timeouts(): + """Test that default timeouts are set for the elastic task.""" + @task(task_config=Elastic(nnodes=1)) + def test_task(): + pass + + assert test_task.task_config.rdzv_configs == {"join_timeout": 900, "timeout": 900} \ No newline at end of file From ddc4b6ee4f84850f73f028dcd4a84ebd37031101 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Mon, 1 Jul 2024 22:22:08 +0200 Subject: [PATCH 03/11] Lint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Fabio Grätz --- plugins/flytekit-kf-pytorch/tests/test_shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-kf-pytorch/tests/test_shared.py b/plugins/flytekit-kf-pytorch/tests/test_shared.py index 544d50135b..796b341ace 100644 --- a/plugins/flytekit-kf-pytorch/tests/test_shared.py +++ b/plugins/flytekit-kf-pytorch/tests/test_shared.py @@ -19,7 +19,7 @@ # Test disabling shared memory volume (PyTorch(num_workers=3, increase_shared_mem=False), None, False, False), (Elastic(nnodes=2, increase_shared_mem=False), None, False, False), - # Test that explictly passed pod template does not break adding shm volume + # Test that 
explicitly passed pod template does not break adding shm volume (Elastic(nnodes=2, increase_shared_mem=True), PodTemplate(), True, False), # Test that pod template with container does not break adding shm volume ( From 5e8d4bc19b6b4cd24e69ac606033a01e6fc305e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Mon, 1 Jul 2024 22:46:16 +0200 Subject: [PATCH 04/11] Lint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Fabio Grätz --- .../flytekitplugins/kfpytorch/pod_template.py | 4 ++-- plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py | 1 - plugins/flytekit-kf-pytorch/tests/test_elastic_task.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/pod_template.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/pod_template.py index 46238e1aa9..2e4bb3301f 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/pod_template.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/pod_template.py @@ -1,6 +1,6 @@ -from flytekit.core.pod_template import PodTemplate +from kubernetes.client import V1Container, V1EmptyDirVolumeSource, V1PodSpec, V1Volume, V1VolumeMount -from kubernetes.client import V1Container, V1PodSpec, V1Volume, V1VolumeMount, V1EmptyDirVolumeSource +from flytekit.core.pod_template import PodTemplate def add_shared_mem_volume_to_pod_template(pod_template: PodTemplate) -> None: diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py index bd9b957fca..254ae0018d 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py @@ -24,7 +24,6 @@ from .error_handling import create_recoverable_error_file, is_recoverable_worker_error from .pod_template import add_shared_mem_volume_to_pod_template - cloudpickle = 
lazy_module("cloudpickle") TORCH_IMPORT_ERROR_MESSAGE = "PyTorch is not installed. Please install `flytekitplugins-kfpytorch['elastic']`." diff --git a/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py index 01bf143cef..52d267740d 100644 --- a/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py +++ b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py @@ -195,4 +195,4 @@ def test_default_timeouts(): def test_task(): pass - assert test_task.task_config.rdzv_configs == {"join_timeout": 900, "timeout": 900} \ No newline at end of file + assert test_task.task_config.rdzv_configs == {"join_timeout": 900, "timeout": 900} From 927e9012af7aa9a2a9f6d878a59ed2ac708c4ad2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Wed, 3 Jul 2024 08:31:33 +0200 Subject: [PATCH 05/11] Add shm size upper limit to docstring MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Fabio Grätz --- plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py index 254ae0018d..ee04247d92 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py @@ -109,6 +109,7 @@ class PyTorch(object): (e.g. for multithreaded data loaders) the default shared memory segment size that the container runs with might not be enough and and one might have to increase the shared memory size. This option configures the task's pod template to mount an `emptyDir` volume with medium `Memory` to to `/dev/shm`. + The shared memory size upper limit is the sum of the memory limits of the containers in the pod. """ master: Master = field(default_factory=lambda: Master()) @@ -145,6 +146,7 @@ class Elastic(object): (e.g. 
for multithreaded data loaders) the default shared memory segment size that the container runs with might not be enough and and one might have to increase the shared memory size. This option configures the task's pod template to mount an `emptyDir` volume with medium `Memory` to to `/dev/shm`. + The shared memory size upper limit is the sum of the memory limits of the containers in the pod. """ nnodes: Union[int, str] = 1 From 6c565dc01606ef64e70841cf8da7122835ac20f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Wed, 3 Jul 2024 08:43:36 +0200 Subject: [PATCH 06/11] Correct docstring: multi-threading -> multi-processing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Fabio Grätz --- plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py index ee04247d92..1ac31b8e34 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py @@ -106,7 +106,7 @@ class PyTorch(object): run_policy: Configuration for the run policy. num_workers: [DEPRECATED] This argument is deprecated. Use `worker.replicas` instead. increase_shared_mem (bool): PyTorch uses shared memory to share data between processes. If torch multiprocessing is used - (e.g. for multithreaded data loaders) the default shared memory segment size that the container runs with might not be enough + (e.g. for multi-processed data loaders) the default shared memory segment size that the container runs with might not be enough and and one might have to increase the shared memory size. This option configures the task's pod template to mount an `emptyDir` volume with medium `Memory` to to `/dev/shm`. 
The shared memory size upper limit is the sum of the memory limits of the containers in the pod. @@ -143,7 +143,7 @@ class Elastic(object): be assigned to a running node which might have the image in its cache while other workers might require a node scale up and image pull. increase_shared_mem (bool): PyTorch uses shared memory to share data between processes. If torch multiprocessing is used - (e.g. for multithreaded data loaders) the default shared memory segment size that the container runs with might not be enough + (e.g. for multi-processed data loaders) the default shared memory segment size that the container runs with might not be enough and and one might have to increase the shared memory size. This option configures the task's pod template to mount an `emptyDir` volume with medium `Memory` to to `/dev/shm`. The shared memory size upper limit is the sum of the memory limits of the containers in the pod. From 13a90b96ac154c65463e3ece9320d8103c9c359a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Thu, 18 Jul 2024 20:09:56 +0200 Subject: [PATCH 07/11] Add kubernetes dep MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Fabio Grätz --- plugins/flytekit-kf-pytorch/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-kf-pytorch/setup.py b/plugins/flytekit-kf-pytorch/setup.py index cc90e0b299..317ca7b8a0 100644 --- a/plugins/flytekit-kf-pytorch/setup.py +++ b/plugins/flytekit-kf-pytorch/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["cloudpickle", "flyteidl>=1.5.1", "flytekit>=1.6.1"] +plugin_requires = ["cloudpickle", "flyteidl>=1.5.1", "flytekit>=1.6.1", "kubernetes"] __version__ = "0.0.0+develop" From d03b6af5f2a7d195ab4191e6ed07a0e47b1d1aa9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Thu, 18 Jul 2024 20:14:27 +0200 Subject: [PATCH 08/11] Refactor the check of num containers 
as proposed in code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Fabio Grätz --- .../flytekitplugins/kfpytorch/pod_template.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/pod_template.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/pod_template.py index 2e4bb3301f..8cc395bb46 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/pod_template.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/pod_template.py @@ -19,17 +19,17 @@ def add_shared_mem_volume_to_pod_template(pod_template: PodTemplate) -> None: pod_template.pod_spec.volumes.append(shm_volume) num_containers = len(pod_template.pod_spec.containers) - if num_containers == 0: - pod_template.pod_spec.containers.append(V1Container(name="primary")) - elif num_containers == 1: - pass - else: + + if num_containers >= 2: raise ValueError( "When configuring a pod template with multiple containers, please set `increase_shared_mem=False` " "in the task config and if required mount a volume to increase the shared memory size in the respective " "container yourself." 
) + if num_containers != 1: + pod_template.pod_spec.containers.append(V1Container(name="primary")) + if pod_template.pod_spec.containers[0].volume_mounts is None: pod_template.pod_spec.containers[0].volume_mounts = [] From 60c8211a2fd135e89f4cd972d6973eac232d18f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Thu, 18 Jul 2024 20:22:56 +0200 Subject: [PATCH 09/11] Lint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Fabio Grätz --- plugins/flytekit-kf-pytorch/tests/test_elastic_task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py index e6aa1815a1..b56fc0aa08 100644 --- a/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py +++ b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py @@ -204,7 +204,7 @@ def test_task(): assert test_task.task_config.rdzv_configs == {"join_timeout": 900, "timeout": 900} - + def test_run_policy() -> None: """Test that run policy is propagated to custom spec.""" From d6941d5655452538f2ee79ac18f2027960c06f3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Fri, 19 Jul 2024 18:16:58 +0200 Subject: [PATCH 10/11] Test that explicitly configured pod template with shm vol is not removed if disabled in task config MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Fabio Grätz --- plugins/flytekit-kf-pytorch/tests/test_shared.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/plugins/flytekit-kf-pytorch/tests/test_shared.py b/plugins/flytekit-kf-pytorch/tests/test_shared.py index 796b341ace..9853feda50 100644 --- a/plugins/flytekit-kf-pytorch/tests/test_shared.py +++ b/plugins/flytekit-kf-pytorch/tests/test_shared.py @@ -58,6 +58,20 @@ True, True, ), + # Test that explicitly configured pod template with shared memory volume is not removed if
`increase_shared_mem=False` + ( + Elastic(nnodes=2, increase_shared_mem=False), + PodTemplate( + pod_spec=V1PodSpec( + containers=[ + V1Container(name="primary", volume_mounts=[V1VolumeMount(name="shm", mount_path="/dev/shm")]), + ], + volumes=[V1Volume(name="shm", empty_dir=V1EmptyDirVolumeSource(medium="Memory"))], + ), + ), + True, + False, + ), ], ) def test_task_shared_memory( @@ -90,6 +104,7 @@ def test_task() -> None: ) else: + # Check that the shared memory volume + volume mount is not added no_pod_template = test_task.pod_template is None no_pod_spec = no_pod_template or test_task.pod_template.pod_spec is None no_volumes = no_pod_spec or test_task.pod_template.pod_spec.volumes is None From cadd905db98a6550b3b4b27cfaea7e96d40fb0b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Fri, 19 Jul 2024 18:27:33 +0200 Subject: [PATCH 11/11] Raise if the user explicitly configured shm vol mount and still sets task config to add it MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Fabio Grätz --- .../flytekitplugins/kfpytorch/pod_template.py | 13 ++++++++++++- plugins/flytekit-kf-pytorch/tests/test_shared.py | 14 ++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/pod_template.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/pod_template.py index 8cc395bb46..8d8567d3e7 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/pod_template.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/pod_template.py @@ -5,8 +5,9 @@ def add_shared_mem_volume_to_pod_template(pod_template: PodTemplate) -> None: """Add shared memory volume and volume mount to the pod template.""" + mount_path = "/dev/shm" shm_volume = V1Volume(name="shm", empty_dir=V1EmptyDirVolumeSource(medium="Memory")) - shm_volume_mount = V1VolumeMount(name="shm", mount_path="/dev/shm") + shm_volume_mount = 
V1VolumeMount(name="shm", mount_path=mount_path) if pod_template.pod_spec is None: pod_template.pod_spec = V1PodSpec() @@ -16,6 +17,7 @@ def add_shared_mem_volume_to_pod_template(pod_template: PodTemplate) -> None: if pod_template.pod_spec.volumes is None: pod_template.pod_spec.volumes = [] + pod_template.pod_spec.volumes.append(shm_volume) num_containers = len(pod_template.pod_spec.containers) @@ -33,4 +35,13 @@ def add_shared_mem_volume_to_pod_template(pod_template: PodTemplate) -> None: if pod_template.pod_spec.containers[0].volume_mounts is None: pod_template.pod_spec.containers[0].volume_mounts = [] + has_shared_mem_vol_mount = any( + [v.mount_path == mount_path for v in pod_template.pod_spec.containers[0].volume_mounts] + ) + if has_shared_mem_vol_mount: + raise ValueError( + "A shared memory volume mount is already configured in the pod template. " + "Please remove the volume mount or set `increase_shared_mem=False` in the task config." + ) + pod_template.pod_spec.containers[0].volume_mounts.append(shm_volume_mount) diff --git a/plugins/flytekit-kf-pytorch/tests/test_shared.py b/plugins/flytekit-kf-pytorch/tests/test_shared.py index 9853feda50..b86f9a73d9 100644 --- a/plugins/flytekit-kf-pytorch/tests/test_shared.py +++ b/plugins/flytekit-kf-pytorch/tests/test_shared.py @@ -72,6 +72,20 @@ True, False, ), + # Test that we raise if the user explicitly configured a shared memory volume and still configures the task config to add it + ( + Elastic(nnodes=2, increase_shared_mem=True), + PodTemplate( + pod_spec=V1PodSpec( + containers=[ + V1Container(name="primary", volume_mounts=[V1VolumeMount(name="shm", mount_path="/dev/shm")]), + ], + volumes=[V1Volume(name="shm", empty_dir=V1EmptyDirVolumeSource(medium="Memory"))], + ), + ), + True, + True, + ), ], ) def test_task_shared_memory(