From cbf9b7f198bf2f1e255e0dda5c47324b63cc8bd3 Mon Sep 17 00:00:00 2001 From: Stephen Knox Date: Thu, 29 Feb 2024 16:59:01 +0000 Subject: [PATCH] feat: add ephemeral_storage parameter for Batch (Fargate-only) (#1739) * feat: add ephemeral storage parameter for Batch Fargate Fix some documentation for other params * Further clarify docs * fix: black reformat --------- Co-authored-by: Stephen Knox --- metaflow/plugins/aws/batch/batch.py | 4 ++++ metaflow/plugins/aws/batch/batch_cli.py | 8 ++++++++ metaflow/plugins/aws/batch/batch_client.py | 15 +++++++++++++++ metaflow/plugins/aws/batch/batch_decorator.py | 7 ++++++- .../plugins/aws/step_functions/step_functions.py | 1 + 5 files changed, 34 insertions(+), 1 deletion(-) diff --git a/metaflow/plugins/aws/batch/batch.py b/metaflow/plugins/aws/batch/batch.py index adadbd57870..bf7188501e4 100644 --- a/metaflow/plugins/aws/batch/batch.py +++ b/metaflow/plugins/aws/batch/batch.py @@ -192,6 +192,7 @@ def create_job( tmpfs_size=None, tmpfs_path=None, num_parallel=0, + ephemeral_storage=None, ): job_name = self._job_name( attrs.get("metaflow.user"), @@ -240,6 +241,7 @@ def create_job( tmpfs_size=tmpfs_size, tmpfs_path=tmpfs_path, num_parallel=num_parallel, + ephemeral_storage=ephemeral_storage, ) .task_id(attrs.get("metaflow.task_id")) .environment_variable("AWS_DEFAULT_REGION", self._client.region()) @@ -353,6 +355,7 @@ def launch_job( num_parallel=0, env={}, attrs={}, + ephemeral_storage=None, ): if queue is None: queue = next(self._client.active_job_queues(), None) @@ -390,6 +393,7 @@ def launch_job( tmpfs_size=tmpfs_size, tmpfs_path=tmpfs_path, num_parallel=num_parallel, + ephemeral_storage=ephemeral_storage, ) self.num_parallel = num_parallel self.job = job.execute() diff --git a/metaflow/plugins/aws/batch/batch_cli.py b/metaflow/plugins/aws/batch/batch_cli.py index 67863f17111..805e3ad8614 100644 --- a/metaflow/plugins/aws/batch/batch_cli.py +++ b/metaflow/plugins/aws/batch/batch_cli.py @@ -154,6 +154,12 @@ def kill(ctx, run_id, user, my_runs): @click.option("--ubf-context", default=None, type=click.Choice([None, "ubf_control"])) @click.option("--host-volumes", multiple=True) @click.option("--efs-volumes", multiple=True) +@click.option( + "--ephemeral-storage", + default=None, + type=int, + help="Ephemeral storage (for AWS Batch only)", +) @click.option( "--num-parallel", default=0, @@ -186,6 +192,7 @@ def step( tmpfs_path=None, host_volumes=None, efs_volumes=None, + ephemeral_storage=None, num_parallel=None, **kwargs ): @@ -317,6 +324,7 @@ def _sync_metadata(): tmpfs_tempdir=tmpfs_tempdir, tmpfs_size=tmpfs_size, tmpfs_path=tmpfs_path, + ephemeral_storage=ephemeral_storage, num_parallel=num_parallel, ) except Exception as e: diff --git a/metaflow/plugins/aws/batch/batch_client.py b/metaflow/plugins/aws/batch/batch_client.py index cb0d3207bf2..cfd38b49b66 100644 --- a/metaflow/plugins/aws/batch/batch_client.py +++ b/metaflow/plugins/aws/batch/batch_client.py @@ -158,6 +158,7 @@ def _register_job_definition( tmpfs_size, tmpfs_path, num_parallel, + ephemeral_storage, ): # identify platform from any compute environment associated with the # queue @@ -210,6 +211,10 @@ def _register_job_definition( job_definition["containerProperties"]["networkConfiguration"] = { "assignPublicIp": "ENABLED" } + if ephemeral_storage: + job_definition["containerProperties"]["ephemeralStorage"] = { + "sizeInGiB": ephemeral_storage + } if platform == "EC2" or platform == "SPOT": if "linuxParameters" not in job_definition["containerProperties"]: @@ -254,6 +259,10 @@ def _register_job_definition( job_definition["containerProperties"]["linuxParameters"][ "maxSwap" ] = int(max_swap) + if ephemeral_storage: + raise BatchJobException( + "The ephemeral_storage parameter is only available for FARGATE compute environments" + ) if inferentia: if not (isinstance(inferentia, (int, unicode, basestring))): @@ -315,6 +324,10 @@ def _register_job_definition( {"sourceVolume": name, "containerPath": container_path} ) + if use_tmpfs and (platform == "FARGATE" or platform == "FARGATE_SPOT"): + raise BatchJobException( + "tmpfs is not available for Fargate compute resources" + ) if use_tmpfs or (tmpfs_size and not use_tmpfs): if tmpfs_size: if not (isinstance(tmpfs_size, (int, unicode, basestring))): @@ -442,6 +455,7 @@ def job_def( tmpfs_size, tmpfs_path, num_parallel, + ephemeral_storage, ): self.payload["jobDefinition"] = self._register_job_definition( image, @@ -461,6 +475,7 @@ def job_def( tmpfs_size, tmpfs_path, num_parallel, + ephemeral_storage, ) return self diff --git a/metaflow/plugins/aws/batch/batch_decorator.py b/metaflow/plugins/aws/batch/batch_decorator.py index c269094c897..dd0f7bd97b4 100644 --- a/metaflow/plugins/aws/batch/batch_decorator.py +++ b/metaflow/plugins/aws/batch/batch_decorator.py @@ -72,7 +72,8 @@ class BatchDecorator(StepDecorator): necessary. A swappiness value of 100 causes pages to be swapped very aggressively. Accepted values are whole numbers between 0 and 100. use_tmpfs : bool, default False - This enables an explicit tmpfs mount for this step. + This enables an explicit tmpfs mount for this step. Note that tmpfs is + not available on Fargate compute environments tmpfs_tempdir : bool, default True sets METAFLOW_TEMPDIR to tmpfs_path if set for this step. tmpfs_size : int, optional, default None @@ -85,6 +86,9 @@ class BatchDecorator(StepDecorator): Number of Inferentia chips required for this step. efa : int, default 0 Number of elastic fabric adapter network devices to attach to container + ephemeral_storage: int, default None + The total amount, in GiB, of ephemeral storage to set for the task (21-200) + This is only relevant for Fargate compute environments """ name = "batch" @@ -107,6 +111,7 @@ class BatchDecorator(StepDecorator): "tmpfs_tempdir": True, "tmpfs_size": None, "tmpfs_path": "/metaflow_temp", + "ephemeral_storage": None, } resource_defaults = { "cpu": "1", diff --git a/metaflow/plugins/aws/step_functions/step_functions.py b/metaflow/plugins/aws/step_functions/step_functions.py index 5ad3bfceca3..9b6a63ad0d1 100644 --- a/metaflow/plugins/aws/step_functions/step_functions.py +++ b/metaflow/plugins/aws/step_functions/step_functions.py @@ -839,6 +839,7 @@ def _batch(self, node): attrs=attrs, host_volumes=resources["host_volumes"], efs_volumes=resources["efs_volumes"], + ephemeral_storage=resources["ephemeral_storage"], ) .attempts(total_retries + 1) )