@@ -36,11 +36,15 @@
exponential_backoff_retry,
)
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.stats import Stats
from airflow.utils import timezone
from airflow.utils.helpers import merge_dicts

if TYPE_CHECKING:
from sqlalchemy.orm import Session

from airflow.executors import workloads
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
from airflow.providers.amazon.aws.executors.batch.boto_schema import (
BatchDescribeJobsResponseSchema,
@@ -97,6 +101,11 @@ class AwsBatchExecutor(BaseExecutor):
# AWS only allows a maximum number of jobs per describe_jobs call
DESCRIBE_JOBS_BATCH_SIZE = 99
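
A minimal sketch, not the executor's code, of how a caller can respect this limit; `batch_client` and `all_job_ids` are hypothetical stand-ins for a boto3 Batch client and the IDs of previously submitted jobs:

def chunked(seq: list[str], size: int):
    # Yield consecutive slices of at most `size` job IDs.
    for i in range(0, len(seq), size):
        yield seq[i : i + size]

# for ids in chunked(all_job_ids, DESCRIBE_JOBS_BATCH_SIZE):
#     response = batch_client.describe_jobs(jobs=ids)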

if TYPE_CHECKING and AIRFLOW_V_3_0_PLUS:
# In the v3 path, we store workloads, not commands as strings.
# TODO: TaskSDK: move this type change into BaseExecutor
queued_tasks: dict[TaskInstanceKey, workloads.All] # type: ignore[assignment]
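
The `if TYPE_CHECKING and AIRFLOW_V_3_0_PLUS` guard above re-declares `queued_tasks` for static checkers only; the block never executes at runtime, so the attribute inherited from `BaseExecutor` is untouched. A standalone sketch of the same pattern, with hypothetical classes:

from typing import TYPE_CHECKING

class Base:
    items: dict[str, str]

class Child(Base):
    if TYPE_CHECKING:
        # Seen by mypy/pyright only; the ignore silences the
        # incompatible-override error, exactly as in the diff above.
        items: dict[str, object]  # type: ignore[assignment]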

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.active_workers = BatchJobCollection()
@@ -106,6 +115,30 @@ def __init__(self, *args, **kwargs):
self.IS_BOTO_CONNECTION_HEALTHY = False
self.submit_job_kwargs = self._load_submit_kwargs()

def queue_workload(self, workload: workloads.All, session: Session | None) -> None:
from airflow.executors import workloads

if not isinstance(workload, workloads.ExecuteTask):
raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}")
ti = workload.ti
self.queued_tasks[ti.key] = workload

def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
from airflow.executors.workloads import ExecuteTask

# Airflow V3 version
for w in workloads:
if not isinstance(w, ExecuteTask):
raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(w)}")
command = [w]
key = w.ti.key
queue = w.ti.queue
executor_config = w.ti.executor_config or {}

del self.queued_tasks[key]
self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config) # type: ignore[arg-type]
self.running.add(key)
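
Together, the two methods above split the v3 hand-off: `queue_workload` stores each workload under its task-instance key, and `_process_workloads` drains `queued_tasks`, wrapping every workload in a one-element list so the pre-existing `execute_async` signature (which expects a command list) still applies. A rough sketch of that sequence, with plain strings standing in for `TaskInstanceKey` and `ExecuteTask`:

queued_tasks: dict[str, object] = {}
running: set[str] = set()

def queue_workload(key: str, workload: object) -> None:
    queued_tasks[key] = workload  # step 1: remember the workload by key

def process_workloads() -> None:
    # step 2: move each key from queued to running, handing [workload]
    # to execute_async as a single-element "command" list
    for key in list(queued_tasks):
        workload = queued_tasks.pop(key)
        running.add(key)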

def check_health(self):
"""Make a test API call to check the health of the Batch Executor."""
success_status = "succeeded."
@@ -343,6 +376,24 @@ def execute_async(self, key: TaskInstanceKey, command: CommandType, queue=None,
if executor_config and "command" in executor_config:
raise ValueError('Executor Config should never override "command"')

if len(command) == 1:
from airflow.executors.workloads import ExecuteTask

if isinstance(command[0], ExecuteTask):
workload = command[0]
ser_input = workload.model_dump_json()
command = [
"python",
"-m",
"airflow.sdk.execution_time.execute_workload",
"--json-string",
ser_input,
]
else:
raise ValueError(
f"BatchExecutor doesn't know how to handle workload of type: {type(command[0])}"
)

self.pending_jobs.append(
BatchQueuedJob(
key=key,
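
The `execute_async` hunk above turns a queued v3 workload back into an argv list for the Batch container: the Pydantic model is serialized with `model_dump_json()` and handed to the Task SDK entrypoint as a JSON string. A runnable sketch of that translation, using a dataclass stand-in for the real `ExecuteTask` model:

import json
from dataclasses import dataclass, field


@dataclass
class FakeWorkload:
    # Hypothetical stand-in for airflow.executors.workloads.ExecuteTask.
    payload: dict = field(default_factory=dict)

    def model_dump_json(self) -> str:
        return json.dumps(self.payload)


def workload_to_command(workload: FakeWorkload) -> list[str]:
    # Mirrors the branch above: serialize, then invoke the SDK entrypoint.
    return [
        "python",
        "-m",
        "airflow.sdk.execution_time.execute_workload",
        "--json-string",
        workload.model_dump_json(),
    ]


print(workload_to_command(FakeWorkload({"test_key": "test_value"})))
# ['python', '-m', 'airflow.sdk.execution_time.execute_workload',
#  '--json-string', '{"test_key": "test_value"}']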
@@ -48,6 +48,7 @@

from tests_common import RUNNING_TESTS_AGAINST_AIRFLOW_PACKAGES
from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS

airflow_version = VersionInfo(*map(int, airflow_version_str.split(".")[:3]))
ARN1 = "arn1"
@@ -199,6 +200,74 @@ def test_execute(self, mock_executor):
mock_executor.batch.submit_job.assert_called_once()
assert len(mock_executor.active_workers) == 1

@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3+")
@mock.patch("airflow.providers.amazon.aws.executors.batch.batch_executor.AwsBatchExecutor.running_state")
def test_task_sdk(self, running_state_mock, mock_airflow_key, mock_executor, mock_cmd):
"""Test task sdk execution from end-to-end."""
from airflow.executors.workloads import ExecuteTask

workload = mock.Mock(spec=ExecuteTask)
workload.ti = mock.Mock(spec=TaskInstance)
workload.ti.key = mock_airflow_key()
tags_exec_config = [{"key": "FOO", "value": "BAR"}]
workload.ti.executor_config = {"tags": tags_exec_config}
ser_workload = json.dumps({"test_key": "test_value"})
workload.model_dump_json.return_value = ser_workload

mock_executor.queue_workload(workload, mock.Mock())

mock_executor.batch.submit_job.return_value = {"jobId": ARN1, "jobName": "some-job-name"}

assert mock_executor.queued_tasks[workload.ti.key] == workload
assert len(mock_executor.pending_jobs) == 0
assert len(mock_executor.running) == 0
mock_executor._process_workloads([workload])
assert len(mock_executor.queued_tasks) == 0
assert len(mock_executor.running) == 1
assert workload.ti.key in mock_executor.running
assert len(mock_executor.pending_jobs) == 1
assert mock_executor.pending_jobs[0].command == [
"python",
"-m",
"airflow.sdk.execution_time.execute_workload",
"--json-string",
'{"test_key": "test_value"}',
]

mock_executor.attempt_submit_jobs()
mock_executor.batch.submit_job.assert_called_once()
assert len(mock_executor.pending_jobs) == 0
mock_executor.batch.submit_job.assert_called_once_with(
jobDefinition="some-job-def",
jobName="some-job-name",
jobQueue="some-job-queue",
tags=tags_exec_config,
containerOverrides={
"command": [
"python",
"-m",
"airflow.sdk.execution_time.execute_workload",
"--json-string",
ser_workload,
],
"environment": [
{
"name": "AIRFLOW_IS_EXECUTOR_CONTAINER",
"value": "true",
},
],
},
)

# Task is stored in active worker.
assert len(mock_executor.active_workers) == 1
# Get the job_id for this task key
job_id = next(
job_id for job_id, key in mock_executor.active_workers.id_to_key.items() if key == workload.ti.key
)
assert job_id == ARN1
running_state_mock.assert_called_once_with(workload.ti.key, ARN1)

@mock.patch.object(batch_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0))
def test_attempt_all_jobs_when_some_jobs_fail(self, _, mock_executor):
"""