Skip to content

Commit

Permalink
ECS Overrides for AWS Batch submit_job
Browse files Browse the repository at this point in the history
  • Loading branch information
yehoshuadimarsky committed May 29, 2024
1 parent fa47f74 commit 02d14c8
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 5 deletions.
3 changes: 3 additions & 0 deletions airflow/providers/amazon/aws/hooks/batch_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def submit_job(
arrayProperties: dict,
parameters: dict,
containerOverrides: dict,
ecsPropertiesOverride: dict,
tags: dict,
) -> dict:
"""
Expand All @@ -119,6 +120,8 @@ def submit_job(
:param containerOverrides: the same parameter that boto3 will receive
:param ecsPropertiesOverride: the same parameter that boto3 will receive
:param tags: the same parameter that boto3 will receive
:return: an API response
Expand Down
8 changes: 8 additions & 0 deletions airflow/providers/amazon/aws/operators/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class BatchOperator(BaseOperator):
:param job_queue: the queue name on AWS Batch
:param overrides: DEPRECATED, use container_overrides instead with the same value.
:param container_overrides: the `containerOverrides` parameter for boto3 (templated)
:param ecs_properties_override: the `ecsPropertiesOverride` parameter for boto3 (templated)
:param node_overrides: the `nodeOverrides` parameter for boto3 (templated)
:param share_identifier: The share identifier for the job. Don't specify this parameter if the job queue
doesn't have a scheduling policy.
Expand Down Expand Up @@ -112,6 +113,7 @@ class BatchOperator(BaseOperator):
"job_queue",
"container_overrides",
"array_properties",
"ecs_properties_override",
"node_overrides",
"parameters",
"retry_strategy",
Expand All @@ -124,6 +126,7 @@ class BatchOperator(BaseOperator):
template_fields_renderers = {
"container_overrides": "json",
"parameters": "json",
"ecs_properties_override": "json",
"node_overrides": "json",
"retry_strategy": "json",
}
Expand Down Expand Up @@ -160,6 +163,7 @@ def __init__(
overrides: dict | None = None, # deprecated
container_overrides: dict | None = None,
array_properties: dict | None = None,
ecs_properties_override: dict | None = None,
node_overrides: dict | None = None,
share_identifier: str | None = None,
scheduling_priority_override: int | None = None,
Expand Down Expand Up @@ -201,6 +205,7 @@ def __init__(
stacklevel=2,
)

self.ecs_properties_override = ecs_properties_override
self.node_overrides = node_overrides
self.share_identifier = share_identifier
self.scheduling_priority_override = scheduling_priority_override
Expand Down Expand Up @@ -296,6 +301,8 @@ def submit_job(self, context: Context):
self.log.info("AWS Batch job - container overrides: %s", self.container_overrides)
if self.array_properties:
self.log.info("AWS Batch job - array properties: %s", self.array_properties)
if self.ecs_properties_override:
self.log.info("AWS Batch job - ECS properties: %s", self.ecs_properties_override)
if self.node_overrides:
self.log.info("AWS Batch job - node properties: %s", self.node_overrides)

Expand All @@ -307,6 +314,7 @@ def submit_job(self, context: Context):
"parameters": self.parameters,
"tags": self.tags,
"containerOverrides": self.container_overrides,
"ecsPropertiesOverride": self.ecs_properties_override,
"nodeOverrides": self.node_overrides,
"retryStrategy": self.retry_strategy,
"shareIdentifier": self.share_identifier,
Expand Down
73 changes: 68 additions & 5 deletions tests/providers/amazon/aws/operators/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def test_init_defaults(self):
assert batch_job.retry_strategy is None
assert batch_job.container_overrides is None
assert batch_job.array_properties is None
assert batch_job.ecs_properties_override is None
assert batch_job.node_overrides is None
assert batch_job.share_identifier is None
assert batch_job.scheduling_priority_override is None
Expand All @@ -149,6 +150,7 @@ def test_template_fields_overrides(self):
"job_queue",
"container_overrides",
"array_properties",
"ecs_properties_override",
"node_overrides",
"parameters",
"retry_strategy",
Expand Down Expand Up @@ -204,6 +206,62 @@ def test_execute_with_failures(self):
tags={},
)

@mock.patch.object(BatchClientHook, "get_job_description")
@mock.patch.object(BatchClientHook, "wait_for_job")
@mock.patch.object(BatchClientHook, "check_job_success")
def test_execute_with_ecs_overrides(self, check_mock, wait_mock, job_description_mock):
self.batch.container_overrides = None
self.batch.ecs_properties_override = {
"taskProperties": [
{
"containers": [
{
"command": [
"string",
],
"environment": [
{"name": "string", "value": "string"},
],
"name": "string",
"resourceRequirements": [
{"value": "string", "type": "'GPU'|'VCPU'|'MEMORY'"},
],
},
]
},
]
}
self.batch.execute(self.mock_context)

self.client_mock.submit_job.assert_called_once_with(
jobQueue="queue",
jobName=JOB_NAME,
jobDefinition="hello-world",
ecsPropertiesOverride={
"taskProperties": [
{
"containers": [
{
"command": [
"string",
],
"environment": [
{"name": "string", "value": "string"},
],
"name": "string",
"resourceRequirements": [
{"value": "string", "type": "'GPU'|'VCPU'|'MEMORY'"},
],
},
]
},
]
},
parameters={},
retryStrategy={"attempts": 1},
tags={},
)

@mock.patch.object(BatchClientHook, "check_job_success")
def test_wait_job_complete_using_waiters(self, check_mock):
mock_waiters = mock.Mock()
Expand Down Expand Up @@ -238,7 +296,7 @@ def test_kill_job(self):
self.batch.on_kill()
self.client_mock.terminate_job.assert_called_once_with(jobId=JOB_ID, reason="Task killed by the user")

@pytest.mark.parametrize("override", ["overrides", "node_overrides"])
@pytest.mark.parametrize("override", ["overrides", "node_overrides", "ecs_properties_override"])
@patch(
"airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.client",
new_callable=mock.PropertyMock,
Expand Down Expand Up @@ -269,10 +327,15 @@ def test_override_not_sent_if_not_set(self, client_mock, override):
"parameters": {},
"tags": {},
}
if override == "overrides":
expected_args["containerOverrides"] = {"a": "a"}
else:
expected_args["nodeOverrides"] = {"a": "a"}

py2api = {
"overrides": "containerOverrides",
"node_overrides": "nodeOverrides",
"ecs_properties_override": "ecsPropertiesOverride",
}

expected_args[py2api[override]] = {"a": "a"}

client_mock().submit_job.assert_called_once_with(**expected_args)

def test_deprecated_override_param(self):
Expand Down

0 comments on commit 02d14c8

Please sign in to comment.