diff --git a/docs/source/ext/compatibility.py b/docs/source/ext/compatibility.py
index ea7e90a6a..b7ef037a1 100644
--- a/docs/source/ext/compatibility.py
+++ b/docs/source/ext/compatibility.py
@@ -17,6 +17,7 @@
         "cancel": "Cancel Job",
         "describe": "Describe Job",
         "workspaces": "Workspaces / Patching",
+        "mounts": "Mounts",
     },
 }
diff --git a/docs/source/specs.rst b/docs/source/specs.rst
index 623a43965..b266aa7b1 100644
--- a/docs/source/specs.rst
+++ b/docs/source/specs.rst
@@ -63,6 +63,15 @@ Run Status
 .. autoclass:: ReplicaState
    :members:
 
+
+Mounts
+--------
+
+.. autoclass:: BindMount
+   :members:
+
+.. autofunction:: parse_mounts
+
 Component Linter
 -----------------
 .. automodule:: torchx.specs.file_linter
diff --git a/torchx/components/dist.py b/torchx/components/dist.py
index dd8e9451d..5133658db 100644
--- a/torchx/components/dist.py
+++ b/torchx/components/dist.py
@@ -124,7 +124,7 @@
 import os
 import shlex
 from pathlib import Path
-from typing import Dict, Iterable, Optional
+from typing import Dict, Iterable, List, Optional
 
 import torchx
 import torchx.specs as specs
@@ -146,6 +146,7 @@ def ddp(
     max_retries: int = 0,
     rdzv_backend: str = "c10d",
     rdzv_endpoint: Optional[str] = None,
+    mounts: Optional[List[str]] = None,
 ) -> specs.AppDef:
     """
     Distributed data parallel style application (one role, multi-replica).
@@ -171,6 +172,7 @@ def ddp(
         max_retries: the number of scheduler retries allowed
         rdzv_backend: rendezvous backend (only matters when nnodes > 1)
         rdzv_endpoint: rendezvous server endpoint (only matters when nnodes > 1), defaults to rank0 host for schedulers that support it
+        mounts: the list of mounts to bind mount into the worker environment/container (e.g. ``type=bind,src=/host,dst=/job[,readonly]``)
     """
 
     if (script is None) == (m is None):
@@ -244,6 +246,7 @@ def ddp(
                     "c10d": 29500,
                 },
                 max_retries=max_retries,
+                mounts=specs.parse_mounts(mounts) if mounts else [],
             )
         ],
     )
diff --git a/torchx/components/test/dist_test.py b/torchx/components/test/dist_test.py
index 0ae6e907d..1d5e7d493 100644
--- a/torchx/components/test/dist_test.py
+++ b/torchx/components/test/dist_test.py
@@ -11,3 +11,9 @@ class DistributedComponentTest(ComponentTestCase):
 
     def test_ddp(self) -> None:
         self.validate(dist, "ddp")
+
+    def test_ddp_mounts(self) -> None:
+        app = dist.ddp(
+            script="foo.py", mounts=["type=bind", "src=/src", "dst=/dst", "readonly"]
+        )
+        self.assertEqual(len(app.roles[0].mounts), 1)
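A minimal usage sketch of the new ``mounts`` argument on the ``ddp`` component (script name and paths are hypothetical; the option strings follow the format described in the docstring above):

    from torchx.components import dist

    # roughly equivalent CLI form (sketch):
    #   torchx run dist.ddp --script train.py --mounts type=bind,src=/host/data,dst=/data,readonly
    app = dist.ddp(
        script="train.py",
        mounts=["type=bind", "src=/host/data", "dst=/data", "readonly"],
    )
    assert app.roles[0].mounts[0].read_only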
diff --git a/torchx/schedulers/aws_batch_scheduler.py b/torchx/schedulers/aws_batch_scheduler.py
index a6b3cdc60..bf46f9eee 100644
--- a/torchx/schedulers/aws_batch_scheduler.py
+++ b/torchx/schedulers/aws_batch_scheduler.py
@@ -91,6 +91,26 @@ def _role_to_node_properties(idx: int, role: Role) -> Dict[str, object]:
     if resource.gpu > 0:
         reqs.append({"type": "GPU", "value": str(resource.gpu)})
 
+    mount_points = []
+    volumes = []
+    for i, mount in enumerate(role.mounts):
+        name = f"mount_{i}"
+        volumes.append(
+            {
+                "name": name,
+                "host": {
+                    "sourcePath": mount.src_path,
+                },
+            }
+        )
+        mount_points.append(
+            {
+                "containerPath": mount.dst_path,
+                "readOnly": mount.read_only,
+                "sourceVolume": name,
+            }
+        )
+
     container = {
         "command": [role.entrypoint] + role.args,
         "image": role.image,
@@ -99,6 +119,8 @@ def _role_to_node_properties(idx: int, role: Role) -> Dict[str, object]:
         "logConfiguration": {
             "logDriver": "awslogs",
         },
+        "mountPoints": mount_points,
+        "volumes": volumes,
     }
 
     return {
@@ -165,7 +187,8 @@ class AWSBatchScheduler(Scheduler, DockerWorkspace):
        describe: |
            Partial support.
            AWSBatchScheduler will return job and replica status but does not provide the complete original AppSpec.
-       workspaces: false
+       workspaces: true
+       mounts: true
    """
 
    def __init__(
diff --git a/torchx/schedulers/docker_scheduler.py b/torchx/schedulers/docker_scheduler.py
index 9240d4189..c15b3ebbb 100644
--- a/torchx/schedulers/docker_scheduler.py
+++ b/torchx/schedulers/docker_scheduler.py
@@ -115,6 +115,7 @@ class DockerScheduler(Scheduler, DockerWorkspace):
            Partial support. DockerScheduler will return job and replica status but
            does not provide the complete original AppSpec.
        workspaces: true
+       mounts: true
    """
 
    def __init__(self, session_name: str) -> None:
@@ -171,7 +172,7 @@ class DockerScheduler(Scheduler, DockerWorkspace):
     def _submit_dryrun(
         self, app: AppDef, cfg: Mapping[str, CfgVal]
     ) -> AppDryRunInfo[DockerJob]:
-        from docker.types import DeviceRequest
+        from docker.types import DeviceRequest, Mount
 
         default_env = {}
         copy_env = cfg.get("copy_env")
@@ -189,6 +190,17 @@ class DockerScheduler(Scheduler, DockerWorkspace):
         req = DockerJob(app_id=app_id, containers=[])
         rank0_name = f"{app_id}-{app.roles[0].name}-0"
         for role in app.roles:
+            mounts = []
+            for mount in role.mounts:
+                mounts.append(
+                    Mount(
+                        target=mount.dst_path,
+                        source=mount.src_path,
+                        read_only=mount.read_only,
+                        type="bind",
+                    )
+                )
+
             for replica_id in range(role.num_replicas):
                 values = macros.Values(
                     img_root="",
@@ -220,6 +232,7 @@ class DockerScheduler(Scheduler, DockerWorkspace):
                         },
                         "hostname": name,
                         "network": NETWORK,
+                        "mounts": mounts,
                     },
                 )
                 if replica_role.max_retries > 0:
diff --git a/torchx/schedulers/kubernetes_scheduler.py b/torchx/schedulers/kubernetes_scheduler.py
index 363041706..91878d321 100644
--- a/torchx/schedulers/kubernetes_scheduler.py
+++ b/torchx/schedulers/kubernetes_scheduler.py
@@ -166,6 +166,9 @@ def role_to_pod(name: str, role: Role, service_account: Optional[str]) -> "V1Pod
         V1ResourceRequirements,
         V1ContainerPort,
         V1ObjectMeta,
+        V1VolumeMount,
+        V1Volume,
+        V1HostPathVolumeSource,
     )
 
     requests = {}
@@ -183,6 +186,26 @@ def role_to_pod(name: str, role: Role, service_account: Optional[str]) -> "V1Pod
         requests=requests,
     )
 
+    volumes = []
+    volume_mounts = []
+    for i, mount in enumerate(role.mounts):
+        mount_name = f"mount-{i}"
+        volumes.append(
+            V1Volume(
+                name=mount_name,
+                host_path=V1HostPathVolumeSource(
+                    path=mount.src_path,
+                ),
+            )
+        )
+        volume_mounts.append(
+            V1VolumeMount(
+                name=mount_name,
+                mount_path=mount.dst_path,
+                read_only=mount.read_only,
+            )
+        )
+
     container = V1Container(
         command=[role.entrypoint] + role.args,
         image=role.image,
@@ -202,12 +225,15 @@ def role_to_pod(name: str, role: Role, service_account: Optional[str]) -> "V1Pod
             )
             for name, port in role.port_map.items()
         ],
+        volume_mounts=volume_mounts,
     )
+
     return V1Pod(
         spec=V1PodSpec(
             containers=[container],
             restart_policy="Never",
             service_account_name=service_account,
+            volumes=volumes,
         ),
         metadata=V1ObjectMeta(
             annotations={
@@ -360,6 +386,7 @@ class KubernetesScheduler(Scheduler, DockerWorkspace):
            Partial support. KubernetesScheduler will return job and replica status but
            does not provide the complete original AppSpec.
        workspaces: true
+       mounts: true
    """
 
    def __init__(
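To make the per-scheduler translation above concrete, here is a sketch of the objects a single read-only bind is expected to produce on the Docker and Kubernetes backends (constructor calls only, mirroring the code above; paths are hypothetical):

    from docker.types import Mount
    from kubernetes.client.models import (
        V1HostPathVolumeSource,
        V1Volume,
        V1VolumeMount,
    )

    # Docker: one Mount per BindMount, always type="bind"
    docker_mount = Mount(target="/job", source="/host", read_only=True, type="bind")

    # Kubernetes: a hostPath volume plus a matching volumeMount, paired by name
    k8s_volume = V1Volume(name="mount-0", host_path=V1HostPathVolumeSource(path="/host"))
    k8s_mount = V1VolumeMount(name="mount-0", mount_path="/job", read_only=True)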
diff --git a/torchx/schedulers/local_scheduler.py b/torchx/schedulers/local_scheduler.py
index bab42a865..b95918090 100644
--- a/torchx/schedulers/local_scheduler.py
+++ b/torchx/schedulers/local_scheduler.py
@@ -551,6 +551,7 @@ class LocalScheduler(Scheduler):
        workspaces: |
            Partial support. LocalScheduler runs the app from a local
            directory but does not support programmatic workspaces.
+       mounts: false
    """
 
    def __init__(
diff --git a/torchx/schedulers/ray_scheduler.py b/torchx/schedulers/ray_scheduler.py
index 3800dbe49..f4f490910 100644
--- a/torchx/schedulers/ray_scheduler.py
+++ b/torchx/schedulers/ray_scheduler.py
@@ -108,6 +108,28 @@ class RayJob:
         actors: List[RayActor] = field(default_factory=list)
 
 
     class RayScheduler(Scheduler):
+        """
+        **Config Options**
+
+        .. runopts::
+            class: torchx.schedulers.ray_scheduler.RayScheduler
+
+        **Compatibility**
+
+        .. compatibility::
+            type: scheduler
+            features:
+                cancel: true
+                logs: true
+                distributed: true
+                describe: |
+                    Partial support. RayScheduler will return job status but
+                    does not provide the complete original AppSpec.
+                workspaces: false
+                mounts: false
+
+        """
+
         def __init__(self, session_name: str) -> None:
             super().__init__("ray", session_name)
diff --git a/torchx/schedulers/slurm_scheduler.py b/torchx/schedulers/slurm_scheduler.py
index 5d24ae528..6e24353a4 100644
--- a/torchx/schedulers/slurm_scheduler.py
+++ b/torchx/schedulers/slurm_scheduler.py
@@ -272,6 +272,7 @@ class SlurmScheduler(Scheduler, DirWorkspace):
        workspaces: |
            If ``job_dir`` is specified the DirWorkspace will create a new
            isolated directory with a snapshot of the workspace.
+       mounts: false
    """
 
    def __init__(self, session_name: str) -> None:
diff --git a/torchx/schedulers/test/aws_batch_scheduler_test.py b/torchx/schedulers/test/aws_batch_scheduler_test.py
index a9b9518cd..11765cb68 100644
--- a/torchx/schedulers/test/aws_batch_scheduler_test.py
+++ b/torchx/schedulers/test/aws_batch_scheduler_test.py
@@ -39,6 +39,9 @@ def _test_app() -> specs.AppDef:
         port_map={"foo": 1234},
         num_replicas=2,
         max_retries=3,
+        mounts=[
+            specs.BindMount(src_path="/src", dst_path="/dst", read_only=True),
+        ],
     )
 
     return specs.AppDef("test", roles=[trainer_role])
@@ -109,6 +112,21 @@ class AWSBatchSchedulerTest(unittest.TestCase):
                             {"type": "GPU", "value": "4"},
                         ],
                         "logConfiguration": {"logDriver": "awslogs"},
+                        "mountPoints": [
+                            {
+                                "containerPath": "/dst",
+                                "readOnly": True,
+                                "sourceVolume": "mount_0",
+                            }
+                        ],
+                        "volumes": [
+                            {
+                                "name": "mount_0",
+                                "host": {
+                                    "sourcePath": "/src",
+                                },
+                            }
+                        ],
                     },
                 },
                 {
@@ -136,6 +154,21 @@ class AWSBatchSchedulerTest(unittest.TestCase):
                             {"type": "GPU", "value": "4"},
                         ],
                         "logConfiguration": {"logDriver": "awslogs"},
+                        "mountPoints": [
+                            {
+                                "containerPath": "/dst",
+                                "readOnly": True,
+                                "sourceVolume": "mount_0",
+                            }
+                        ],
+                        "volumes": [
+                            {
+                                "name": "mount_0",
+                                "host": {
+                                    "sourcePath": "/src",
+                                },
+                            }
+                        ],
                     },
                 },
             ],
diff --git a/torchx/schedulers/test/docker_scheduler_test.py b/torchx/schedulers/test/docker_scheduler_test.py
index c94c2e151..4a0fc9871 100644
--- a/torchx/schedulers/test/docker_scheduler_test.py
+++ b/torchx/schedulers/test/docker_scheduler_test.py
@@ -11,7 +11,7 @@
 from unittest.mock import patch
 
 import fsspec
-from docker.types import DeviceRequest
+from docker.types import DeviceRequest, Mount
 from torchx import specs
 from torchx.components.dist import ddp
 from torchx.schedulers.api import Stream
@@ -48,6 +48,9 @@ def _test_app() -> specs.AppDef:
         port_map={"foo": 1234},
         num_replicas=1,
         max_retries=3,
+        mounts=[
+            specs.BindMount(src_path="/tmp", dst_path="/tmp", read_only=True),
+        ],
     )
 
     return specs.AppDef("test", roles=[trainer_role])
@@ -105,6 +108,14 @@ class DockerSchedulerTest(unittest.TestCase):
                             "MaximumRetryCount": 3,
                         },
                         "network": "torchx",
+                        "mounts": [
+                            Mount(
+                                target="/tmp",
+                                source="/tmp",
+                                read_only=True,
+                                type="bind",
+                            )
+                        ],
                     },
                 )
             ],
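The fixtures above attach mounts directly to the Role; as a standalone sketch (image and entrypoint are hypothetical):

    from torchx import specs

    trainer = specs.Role(
        name="trainer",
        image="pytorch/torchx:latest",  # hypothetical image
        entrypoint="main.py",
        mounts=[specs.BindMount(src_path="/tmp", dst_path="/tmp", read_only=True)],
    )
    app = specs.AppDef("test", roles=[trainer])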
diff --git a/torchx/schedulers/test/kubernetes_scheduler_test.py b/torchx/schedulers/test/kubernetes_scheduler_test.py
index 3df1593ae..2b59cbc95 100644
--- a/torchx/schedulers/test/kubernetes_scheduler_test.py
+++ b/torchx/schedulers/test/kubernetes_scheduler_test.py
@@ -51,6 +51,9 @@ def _test_app(num_replicas: int = 1) -> specs.AppDef:
         port_map={"foo": 1234},
         num_replicas=num_replicas,
         max_retries=3,
+        mounts=[
+            specs.BindMount(src_path="/src", dst_path="/dst", read_only=True),
+        ],
     )
 
     return specs.AppDef("test", roles=[trainer_role])
@@ -112,6 +115,9 @@ class KubernetesSchedulerTest(unittest.TestCase):
             V1ResourceRequirements,
             V1ContainerPort,
             V1ObjectMeta,
+            V1Volume,
+            V1VolumeMount,
+            V1HostPathVolumeSource,
         )
 
         app = _test_app()
@@ -141,12 +147,27 @@ class KubernetesSchedulerTest(unittest.TestCase):
             env=[V1EnvVar(name="FOO", value="bar")],
             resources=resources,
             ports=[V1ContainerPort(name="foo", container_port=1234)],
+            volume_mounts=[
+                V1VolumeMount(
+                    name="mount-0",
+                    mount_path="/dst",
+                    read_only=True,
+                )
+            ],
         )
         want = V1Pod(
             spec=V1PodSpec(
                 containers=[container],
                 restart_policy="Never",
                 service_account_name="srvacc",
+                volumes=[
+                    V1Volume(
+                        name="mount-0",
+                        host_path=V1HostPathVolumeSource(
+                            path="/src",
+                        ),
+                    ),
+                ],
             ),
             metadata=V1ObjectMeta(
                 annotations={
@@ -246,7 +267,15 @@ class KubernetesSchedulerTest(unittest.TestCase):
                 cpu: 2000m
                 memory: 3000M
                 nvidia.com/gpu: '4'
+            volumeMounts:
+            - mountPath: /dst
+              name: mount-0
+              readOnly: true
           restartPolicy: Never
+          volumes:
+          - hostPath:
+              path: /src
+            name: mount-0
 """,
         )
 
diff --git a/torchx/specs/__init__.py b/torchx/specs/__init__.py
index d22fe294c..884dfc7c8 100644
--- a/torchx/specs/__init__.py
+++ b/torchx/specs/__init__.py
@@ -26,6 +26,7 @@
     AppHandle,
     AppState,
     AppStatus,
+    BindMount,
     CfgVal,
     InvalidRunConfigException,
     MalformedAppHandleException,
@@ -46,6 +47,7 @@
     parse_app_handle,
+    parse_mounts,
     runopt,
     runopts,
 )
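With the exports above, both names are importable from the top-level ``torchx.specs`` package; a quick round-trip sketch from the flat string form to the typed form (paths hypothetical):

    from torchx.specs import BindMount, parse_mounts

    mounts = parse_mounts(["type=bind", "src=/host", "dst=/job", "readonly"])
    assert mounts == [BindMount(src_path="/host", dst_path="/job", read_only=True)]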
diff --git a/torchx/specs/api.py b/torchx/specs/api.py
index 1f7e20ab0..4cb160bf2 100644
--- a/torchx/specs/api.py
+++ b/torchx/specs/api.py
@@ -196,6 +196,24 @@ class RetryPolicy(str, Enum):
     APPLICATION = "APPLICATION"
 
 
+@dataclass
+class BindMount:
+    """
+    Defines a bind mount to ``mount --bind`` a host path into the worker
+    environment. See the scheduler documentation for how bind mounts behave
+    on each scheduler.
+
+    Args:
+        src_path: the path on the host
+        dst_path: the path in the worker environment/container
+        read_only: whether the bind should be read only
+    """
+
+    src_path: str
+    dst_path: str
+    read_only: bool = False
+
+
 @dataclass
 class Role:
@@ -242,6 +260,7 @@ class Role:
            e.g. "tensorboard": 9090
        metadata: Free form information that is associated with the role, for example
            scheduler specific data. The key should follow the pattern: ``$scheduler.$key``
+       mounts: a list of mounts to mount into the worker environment/container
    """
 
    name: str
@@ -256,6 +275,7 @@ class Role:
     resource: Resource = NULL_RESOURCE
     port_map: Dict[str, int] = field(default_factory=dict)
     metadata: Dict[str, Any] = field(default_factory=dict)
+    mounts: List[BindMount] = field(default_factory=list)
 
     def pre_proc(
         self,
@@ -849,3 +869,56 @@ def from_function(
         cmpnt_defaults,
     )
     return cmpnt_fn(*function_args, *var_arg, **kwargs)
+
+
+_MOUNT_OPT_MAP: Mapping[str, str] = {
+    "type": "type",
+    "destination": "dst",
+    "dst": "dst",
+    "target": "dst",
+    "read_only": "readonly",
+    "readonly": "readonly",
+    "source": "src",
+    "src": "src",
+}
+
+
+def parse_mounts(opts: List[str]) -> List[BindMount]:
+    """
+    parse_mounts parses a list of options into typed mounts following a
+    similar format to Docker's bind mounts.
+
+    Multiple mounts can be specified in the same list. ``type`` must be
+    specified first in each.
+
+    Ex: type=bind,src=/host,dst=/container,readonly[,type=bind,src=...,dst=...]
+
+    Supported types:
+        BindMount: type=bind,src=<host path>,dst=<container path>[,readonly]
+    """
+    mount_opts: List[Dict[str, str]] = []
+    cur: Dict[str, str] = {}
+    for opt in opts:
+        key, _, val = opt.partition("=")
+        if key not in _MOUNT_OPT_MAP:
+            raise KeyError(
+                f"unknown mount option {key}, must be one of {list(_MOUNT_OPT_MAP.keys())}"
+            )
+        key = _MOUNT_OPT_MAP[key]
+        if key == "type":
+            cur = {}
+            mount_opts.append(cur)
+        elif len(mount_opts) == 0:
+            raise KeyError("type must be specified first")
+        cur[key] = val
+
+    mounts: List[BindMount] = []
+    for opts in mount_opts:
+        typ = opts.get("type")
+        assert typ == "bind", "only bind mounts are currently supported"
+        mounts.append(
+            BindMount(
+                src_path=opts["src"], dst_path=opts["dst"], read_only="readonly" in opts
+            )
+        )
+    return mounts
diff --git a/torchx/specs/test/api_test.py b/torchx/specs/test/api_test.py
index c9660e5cf..67b899512 100644
--- a/torchx/specs/test/api_test.py
+++ b/torchx/specs/test/api_test.py
@@ -36,6 +36,8 @@
     make_app_handle,
     parse_app_handle,
     runopts,
+    parse_mounts,
+    BindMount,
 )
@@ -786,3 +788,45 @@ class AppDefLoadTest(unittest.TestCase):
         self.assertTupleEqual(
             (" ", None), none_throws(self._get_argument_help(parser, "roles_args"))
         )
+
+
+class MountsTest(unittest.TestCase):
+    def test_empty(self) -> None:
+        self.assertEqual(parse_mounts([]), [])
+
+    def test_bindmount(self) -> None:
+        self.assertEqual(
+            parse_mounts(
+                [
+                    "type=bind",
+                    "src=foo",
+                    "dst=dst",
+                    "type=bind",
+                    "source=foo1",
+                    "readonly",
+                    "target=dst1",
+                    "type=bind",
+                    "destination=dst2",
+                    "source=foo2",
+                ]
+            ),
+            [
+                BindMount(src_path="foo", dst_path="dst"),
+                BindMount(src_path="foo1", dst_path="dst1", read_only=True),
+                BindMount(src_path="foo2", dst_path="dst2"),
+            ],
+        )
+
+    def test_invalid(self) -> None:
+        with self.assertRaisesRegex(KeyError, "type must be specified first"):
+            parse_mounts(["src=foo"])
+        with self.assertRaisesRegex(
+            KeyError, "unknown mount option blah, must be one of.*type"
+        ):
+            parse_mounts(["blah=foo"])
+        with self.assertRaisesRegex(KeyError, "src"):
+            parse_mounts(["type=bind"])
+        with self.assertRaisesRegex(
+            AssertionError, "only bind mounts are currently supported"
+        ):
+            parse_mounts(["type=foo"])
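A runnable sketch of the failure modes exercised by ``MountsTest`` above:

    from torchx.specs import parse_mounts

    for bad in (["src=/host"], ["blah=/host"], ["type=bind"], ["type=foo"]):
        try:
            parse_mounts(bad)
        except (KeyError, AssertionError) as e:
            # e.g. 'type must be specified first', 'src', or the bind-only assertion
            print(bad, "->", e)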