Skip to content

Commit

Permalink
add 'trainium' alias for inferentia to batch decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
saikonen committed Mar 13, 2024
1 parent 0f4ef1a commit d7d4cab
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions metaflow/plugins/aws/batch/batch_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ class BatchDecorator(StepDecorator):
Path to tmpfs mount for this step. Defaults to /metaflow_temp.
inferentia : int, default 0
Number of Inferentia chips required for this step.
trainium : int, default None
Alias for inferentia. Specify only one of inferentia or trainium, not both.
efa : int, default 0
Number of elastic fabric adapter network devices to attach to container
ephemeral_storage: int, default None
Expand All @@ -104,6 +106,7 @@ class BatchDecorator(StepDecorator):
"max_swap": None,
"swappiness": None,
"inferentia": None,
"trainium": None, # alias for inferentia
"efa": None,
"host_volumes": None,
"efs_volumes": None,
Expand Down Expand Up @@ -151,6 +154,21 @@ def __init__(self, attributes=None, statically_defined=False):
self.attributes["image"],
)

# Alias trainium to inferentia and check that the two are not both set.
if (
self.attributes["inferentia"] is not None
and self.attributes["trainium"] is not None
):
raise BatchException(
"only specify a value for 'inferentia' or 'trainium', not both."
)

if self.attributes["trainium"] is not None:
self.attributes["inferentia"] = self.attributes["trainium"]

# clean up the alias attribute so it is not passed on.
self.attributes.pop("trainium", None)

# Refer to https://github.com/Netflix/metaflow/blob/master/docs/lifecycle.png
# to understand where these functions are invoked in the lifecycle of a
# Metaflow flow.
Expand Down

0 comments on commit d7d4cab

Please sign in to comment.