-
Notifications
You must be signed in to change notification settings - Fork 301
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feat: Improve UX of pytorch-elastic plugin by configuring reasonable …
…defaults (#2543) * Add flag to Elastic and PyTorch task config which configures shared memory volume Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com> * Set reasonable default timeouts for pytorch elastic task config Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com> * Lint Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com> * Lint Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com> * Add shm size upper limit to docstring Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com> * Correct docstring: multi-threading -> multi-processing Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com> * Add kubernetes dep Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com> * Refactor the check of num containers as proposed in code review Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com> * Lint Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com> * Test that explicitly configured pod template with shm vol is not removed if disable in task config Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com> * Raise if the user explicitly configured shm vol mount and still sets task config to add it Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com> --------- Signed-off-by: Fabio Grätz <fabiogratz@googlemail.com> Co-authored-by: Fabio Grätz <fabiogratz@googlemail.com>
- Loading branch information
Showing
5 changed files
with
222 additions
and
2 deletions.
There are no files selected for viewing
47 changes: 47 additions & 0 deletions
47
plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/pod_template.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from kubernetes.client import V1Container, V1EmptyDirVolumeSource, V1PodSpec, V1Volume, V1VolumeMount | ||
|
||
from flytekit.core.pod_template import PodTemplate | ||
|
||
|
||
def add_shared_mem_volume_to_pod_template(pod_template: PodTemplate) -> None: | ||
"""Add shared memory volume and volume mount to the pod template.""" | ||
mount_path = "/dev/shm" | ||
shm_volume = V1Volume(name="shm", empty_dir=V1EmptyDirVolumeSource(medium="Memory")) | ||
shm_volume_mount = V1VolumeMount(name="shm", mount_path=mount_path) | ||
|
||
if pod_template.pod_spec is None: | ||
pod_template.pod_spec = V1PodSpec() | ||
|
||
if pod_template.pod_spec.containers is None: | ||
pod_template.pod_spec.containers = [] | ||
|
||
if pod_template.pod_spec.volumes is None: | ||
pod_template.pod_spec.volumes = [] | ||
|
||
pod_template.pod_spec.volumes.append(shm_volume) | ||
|
||
num_containers = len(pod_template.pod_spec.containers) | ||
|
||
if num_containers >= 2: | ||
raise ValueError( | ||
"When configuring a pod template with multiple containers, please set `increase_shared_mem=False` " | ||
"in the task config and if required mount a volume to increase the shared memory size in the respective " | ||
"container yourself." | ||
) | ||
|
||
if num_containers != 1: | ||
pod_template.pod_spec.containers.append(V1Container(name="primary")) | ||
|
||
if pod_template.pod_spec.containers[0].volume_mounts is None: | ||
pod_template.pod_spec.containers[0].volume_mounts = [] | ||
|
||
has_shared_mem_vol_mount = any( | ||
[v.mount_path == mount_path for v in pod_template.pod_spec.containers[0].volume_mounts] | ||
) | ||
if has_shared_mem_vol_mount: | ||
raise ValueError( | ||
"A shared memory volume mount is already configured in the pod template. " | ||
"Please remove the volume mount or set `increase_shared_mem=False` in the task config." | ||
) | ||
|
||
pod_template.pod_spec.containers[0].volume_mounts.append(shm_volume_mount) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
"""Test functionality that is shared between the pytorch and pytorch-elastic tasks.""" | ||
|
||
from contextlib import nullcontext | ||
from typing import Union | ||
|
||
import pytest | ||
from flytekitplugins.kfpytorch.task import Elastic, PyTorch | ||
from kubernetes.client import V1Container, V1EmptyDirVolumeSource, V1PodSpec, V1Volume, V1VolumeMount | ||
|
||
from flytekit import PodTemplate, task | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"task_config, pod_template, needs_shm_volume, raises", | ||
[ | ||
# Test that by default shared memory volume is added | ||
(PyTorch(num_workers=3), None, True, False), | ||
(Elastic(nnodes=2, increase_shared_mem=True), None, True, False), | ||
# Test disabling shared memory volume | ||
(PyTorch(num_workers=3, increase_shared_mem=False), None, False, False), | ||
(Elastic(nnodes=2, increase_shared_mem=False), None, False, False), | ||
# Test that explicitly passed pod template does not break adding shm volume | ||
(Elastic(nnodes=2, increase_shared_mem=True), PodTemplate(), True, False), | ||
# Test that pod template with container does not break adding shm volume | ||
( | ||
Elastic(nnodes=2), | ||
PodTemplate( | ||
pod_spec=V1PodSpec(containers=[V1Container(name="primary")]), | ||
), | ||
True, | ||
False, | ||
), | ||
# Test that pod template with volume/volume mount does not break adding shm volume | ||
( | ||
Elastic(nnodes=2), | ||
PodTemplate( | ||
pod_spec=V1PodSpec( | ||
containers=[ | ||
V1Container(name="primary", volume_mounts=[V1VolumeMount(name="foo", mount_path="/bar")]) | ||
], | ||
volumes=[V1Volume(name="foo")], | ||
), | ||
), | ||
True, | ||
False, | ||
), | ||
# Test that pod template with multiple containers raises an error | ||
( | ||
Elastic(nnodes=2), | ||
PodTemplate( | ||
pod_spec=V1PodSpec( | ||
containers=[ | ||
V1Container(name="primary"), | ||
V1Container(name="secondary"), | ||
] | ||
), | ||
), | ||
True, | ||
True, | ||
), | ||
# Test that explicitly configured pod template with shared memory volume is not removed if `increase_shared_mem=False` | ||
( | ||
Elastic(nnodes=2, increase_shared_mem=False), | ||
PodTemplate( | ||
pod_spec=V1PodSpec( | ||
containers=[ | ||
V1Container(name="primary", volume_mounts=[V1VolumeMount(name="shm", mount_path="/dev/shm")]), | ||
], | ||
volumes=[V1Volume(name="shm", empty_dir=V1EmptyDirVolumeSource(medium="Memory"))], | ||
), | ||
), | ||
True, | ||
False, | ||
), | ||
# Test that we raise if the user explicitly configured a shared memory volume and still configures the task config to add it | ||
( | ||
Elastic(nnodes=2, increase_shared_mem=True), | ||
PodTemplate( | ||
pod_spec=V1PodSpec( | ||
containers=[ | ||
V1Container(name="primary", volume_mounts=[V1VolumeMount(name="shm", mount_path="/dev/shm")]), | ||
], | ||
volumes=[V1Volume(name="shm", empty_dir=V1EmptyDirVolumeSource(medium="Memory"))], | ||
), | ||
), | ||
True, | ||
True, | ||
), | ||
], | ||
) | ||
def test_task_shared_memory( | ||
task_config: Union[Elastic, PyTorch], pod_template: PodTemplate, needs_shm_volume: bool, raises: bool | ||
): | ||
"""Test that the task pod template is configured with a shared memory volume if needed.""" | ||
|
||
expected_volume = V1Volume(name="shm", empty_dir=V1EmptyDirVolumeSource(medium="Memory")) | ||
expected_volume_mount = V1VolumeMount(name="shm", mount_path="/dev/shm") | ||
|
||
with pytest.raises(ValueError) if raises else nullcontext(): | ||
|
||
@task( | ||
task_config=task_config, | ||
pod_template=pod_template, | ||
) | ||
def test_task() -> None: | ||
pass | ||
|
||
if needs_shm_volume: | ||
assert test_task.pod_template is not None | ||
assert test_task.pod_template.pod_spec is not None | ||
assert test_task.pod_template.pod_spec.volumes is not None | ||
assert test_task.pod_template.pod_spec.containers is not None | ||
assert test_task.pod_template.pod_spec.containers[0].volume_mounts is not None | ||
|
||
assert any([v == expected_volume for v in test_task.pod_template.pod_spec.volumes]) | ||
assert any( | ||
[v == expected_volume_mount for v in test_task.pod_template.pod_spec.containers[0].volume_mounts] | ||
) | ||
|
||
else: | ||
# Check that the shared memory volume + volume mount is not added | ||
no_pod_template = test_task.pod_template is None | ||
no_pod_spec = no_pod_template or test_task.pod_template.pod_spec is None | ||
no_volumes = no_pod_spec or test_task.pod_template.pod_spec.volumes is None | ||
no_containers = no_pod_spec or len(test_task.pod_template.pod_spec.containers) == 0 | ||
no_volume_mounts = no_containers or test_task.pod_template.pod_spec.containers[0].volume_mounts is None | ||
empty_volume_mounts = ( | ||
no_volume_mounts or len(test_task.pod_template.pod_spec.containers[0].volume_mounts) == 0 | ||
) | ||
no_shm_volume_condition = no_volumes or not any( | ||
[v == expected_volume for v in test_task.pod_template.pod_spec.volumes] | ||
) | ||
no_shm_volume_mount_condition = empty_volume_mounts or not any( | ||
[v == expected_volume_mount for v in test_task.pod_template.pod_spec.containers[0].volume_mounts] | ||
) | ||
|
||
assert no_shm_volume_condition | ||
assert no_shm_volume_mount_condition |