Skip to content

Commit

Permalink
feat: add ephemeral_storage parameter for Batch (Fargate-only) (#1739)
Browse files Browse the repository at this point in the history
* feat: add ephemeral storage parameter for Batch Fargate

Fix some documentation for other params

* Further clarify docs

* fix: black reformat

---------

Co-authored-by: Stephen Knox <stephen.knox@maplecroft.com>
  • Loading branch information
stev-0 and Stephen Knox authored Feb 29, 2024
1 parent d3ec9a5 commit cbf9b7f
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 1 deletion.
4 changes: 4 additions & 0 deletions metaflow/plugins/aws/batch/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def create_job(
tmpfs_size=None,
tmpfs_path=None,
num_parallel=0,
ephemeral_storage=None,
):
job_name = self._job_name(
attrs.get("metaflow.user"),
Expand Down Expand Up @@ -240,6 +241,7 @@ def create_job(
tmpfs_size=tmpfs_size,
tmpfs_path=tmpfs_path,
num_parallel=num_parallel,
ephemeral_storage=ephemeral_storage,
)
.task_id(attrs.get("metaflow.task_id"))
.environment_variable("AWS_DEFAULT_REGION", self._client.region())
Expand Down Expand Up @@ -353,6 +355,7 @@ def launch_job(
num_parallel=0,
env={},
attrs={},
ephemeral_storage=None,
):
if queue is None:
queue = next(self._client.active_job_queues(), None)
Expand Down Expand Up @@ -390,6 +393,7 @@ def launch_job(
tmpfs_size=tmpfs_size,
tmpfs_path=tmpfs_path,
num_parallel=num_parallel,
ephemeral_storage=ephemeral_storage,
)
self.num_parallel = num_parallel
self.job = job.execute()
Expand Down
8 changes: 8 additions & 0 deletions metaflow/plugins/aws/batch/batch_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,12 @@ def kill(ctx, run_id, user, my_runs):
@click.option("--ubf-context", default=None, type=click.Choice([None, "ubf_control"]))
@click.option("--host-volumes", multiple=True)
@click.option("--efs-volumes", multiple=True)
@click.option(
"--ephemeral-storage",
default=None,
type=int,
help="Ephemeral storage (for AWS Batch only)",
)
@click.option(
"--num-parallel",
default=0,
Expand Down Expand Up @@ -186,6 +192,7 @@ def step(
tmpfs_path=None,
host_volumes=None,
efs_volumes=None,
ephemeral_storage=None,
num_parallel=None,
**kwargs
):
Expand Down Expand Up @@ -317,6 +324,7 @@ def _sync_metadata():
tmpfs_tempdir=tmpfs_tempdir,
tmpfs_size=tmpfs_size,
tmpfs_path=tmpfs_path,
ephemeral_storage=ephemeral_storage,
num_parallel=num_parallel,
)
except Exception as e:
Expand Down
15 changes: 15 additions & 0 deletions metaflow/plugins/aws/batch/batch_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def _register_job_definition(
tmpfs_size,
tmpfs_path,
num_parallel,
ephemeral_storage,
):
# identify platform from any compute environment associated with the
# queue
Expand Down Expand Up @@ -210,6 +211,10 @@ def _register_job_definition(
job_definition["containerProperties"]["networkConfiguration"] = {
"assignPublicIp": "ENABLED"
}
if ephemeral_storage:
job_definition["containerProperties"]["ephemeralStorage"] = {
"sizeInGiB": ephemeral_storage
}

if platform == "EC2" or platform == "SPOT":
if "linuxParameters" not in job_definition["containerProperties"]:
Expand Down Expand Up @@ -254,6 +259,10 @@ def _register_job_definition(
job_definition["containerProperties"]["linuxParameters"][
"maxSwap"
] = int(max_swap)
if ephemeral_storage:
raise BatchJobException(
"The ephemeral_storage parameter is only available for FARGATE compute environments"
)

if inferentia:
if not (isinstance(inferentia, (int, unicode, basestring))):
Expand Down Expand Up @@ -315,6 +324,10 @@ def _register_job_definition(
{"sourceVolume": name, "containerPath": container_path}
)

if use_tmpfs and (platform == "FARGATE" or platform == "FARGATE_SPOT"):
raise BatchJobException(
"tmpfs is not available for Fargate compute resources"
)
if use_tmpfs or (tmpfs_size and not use_tmpfs):
if tmpfs_size:
if not (isinstance(tmpfs_size, (int, unicode, basestring))):
Expand Down Expand Up @@ -442,6 +455,7 @@ def job_def(
tmpfs_size,
tmpfs_path,
num_parallel,
ephemeral_storage,
):
self.payload["jobDefinition"] = self._register_job_definition(
image,
Expand All @@ -461,6 +475,7 @@ def job_def(
tmpfs_size,
tmpfs_path,
num_parallel,
ephemeral_storage,
)
return self

Expand Down
7 changes: 6 additions & 1 deletion metaflow/plugins/aws/batch/batch_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ class BatchDecorator(StepDecorator):
necessary. A swappiness value of 100 causes pages to be swapped very
aggressively. Accepted values are whole numbers between 0 and 100.
use_tmpfs : bool, default False
This enables an explicit tmpfs mount for this step.
This enables an explicit tmpfs mount for this step. Note that tmpfs is
not available on Fargate compute environments
tmpfs_tempdir : bool, default True
sets METAFLOW_TEMPDIR to tmpfs_path if set for this step.
tmpfs_size : int, optional, default None
Expand All @@ -85,6 +86,9 @@ class BatchDecorator(StepDecorator):
Number of Inferentia chips required for this step.
efa : int, default 0
Number of elastic fabric adapter network devices to attach to container
ephemeral_storage: int, default None
The total amount, in GiB, of ephemeral storage to set for the task (21-200)
This is only relevant for Fargate compute environments
"""

name = "batch"
Expand All @@ -107,6 +111,7 @@ class BatchDecorator(StepDecorator):
"tmpfs_tempdir": True,
"tmpfs_size": None,
"tmpfs_path": "/metaflow_temp",
"ephemeral_storage": None,
}
resource_defaults = {
"cpu": "1",
Expand Down
1 change: 1 addition & 0 deletions metaflow/plugins/aws/step_functions/step_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,7 @@ def _batch(self, node):
attrs=attrs,
host_volumes=resources["host_volumes"],
efs_volumes=resources["efs_volumes"],
ephemeral_storage=resources["ephemeral_storage"],
)
.attempts(total_retries + 1)
)
Expand Down

0 comments on commit cbf9b7f

Please sign in to comment.