@@ -49,12 +49,16 @@
exponential_backoff_retry,
)
from airflow.providers.amazon.aws.hooks.ecs import EcsHook
from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.stats import Stats
from airflow.utils import timezone
from airflow.utils.helpers import merge_dicts
from airflow.utils.state import State

if TYPE_CHECKING:
from sqlalchemy.orm import Session

from airflow.executors import workloads
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
from airflow.providers.amazon.aws.executors.ecs.utils import (
CommandType,
@@ -100,6 +104,11 @@ class AwsEcsExecutor(BaseExecutor):
# AWS limits the maximum number of ARNs in the describe_tasks function.
DESCRIBE_TASKS_BATCH_SIZE = 99

if TYPE_CHECKING and AIRFLOW_V_3_0_PLUS:
# In the v3 path, we store workloads, not commands as strings.
# TODO: TaskSDK: move this type change into BaseExecutor
queued_tasks: dict[TaskInstanceKey, workloads.All] # type: ignore[assignment]

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.active_workers: EcsTaskCollection = EcsTaskCollection()
@@ -114,6 +123,31 @@ def __init__(self, *args, **kwargs):

self.run_task_kwargs = self._load_run_kwargs()

def queue_workload(self, workload: workloads.All, session: Session | None) -> None:
from airflow.executors import workloads

if not isinstance(workload, workloads.ExecuteTask):
raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}")
ti = workload.ti
self.queued_tasks[ti.key] = workload

def _process_workloads(self, workloads: list[workloads.All]) -> None:
from airflow.executors.workloads import ExecuteTask

# Airflow V3 version
for w in workloads:
if not isinstance(w, ExecuteTask):
raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(w)}")

command = [w]
key = w.ti.key
queue = w.ti.queue
executor_config = w.ti.executor_config or {}

del self.queued_tasks[key]
self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config) # type: ignore[arg-type]
self.running.add(key)

def start(self):
"""Call this when the Executor is run for the first time by the scheduler."""
check_health = conf.getboolean(
@@ -462,6 +496,24 @@ def execute_async(self, key: TaskInstanceKey, command: CommandType, queue=None,
"""Save the task to be executed in the next sync by inserting the commands into a queue."""
if executor_config and ("name" in executor_config or "command" in executor_config):
raise ValueError('Executor Config should never override "name" or "command"')
if len(command) == 1:
from airflow.executors.workloads import ExecuteTask

if isinstance(command[0], ExecuteTask):
workload = command[0]
ser_input = workload.model_dump_json()
command = [
"python",
"-m",
"airflow.sdk.execution_time.execute_workload",
"--json-string",
ser_input,
]
else:
raise ValueError(
f"EcsExecutor doesn't know how to handle workload of type: {type(command[0])}"
)

self.pending_tasks.append(
EcsQueuedTask(key, command, queue, executor_config or {}, 1, timezone.utcnow())
)
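For illustration, a minimal sketch of the conversion performed above. FakeWorkload is a hypothetical stand-in for airflow.executors.workloads.ExecuteTask (which additionally carries the task instance, token, bundle info and log path); the sketch only shows how the serialized payload becomes the last element of the container command.

from pydantic import BaseModel


class FakeWorkload(BaseModel):
    # Hypothetical stand-in; the real ExecuteTask is also a pydantic model.
    dag_rel_path: str
    log_path: str


workload = FakeWorkload(dag_rel_path="dags/example_dag.py", log_path="logs/example.log")
# execute_async() wraps the serialized workload in the Task SDK entrypoint invocation:
command = [
    "python",
    "-m",
    "airflow.sdk.execution_time.execute_workload",
    "--json-string",
    workload.model_dump_json(),
]
print(command)  # the last element is the JSON payload that gets decoded again inside the container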
@@ -59,7 +59,7 @@

from tests_common import RUNNING_TESTS_AGAINST_AIRFLOW_PACKAGES
from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS
from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS

pytestmark = pytest.mark.db_test

@@ -412,6 +412,96 @@ def test_execute(self, change_state_mock, mock_airflow_key, mock_executor, mock_
airflow_key, TaskInstanceState.RUNNING, ARN1, remove_running=False
)

@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3+")
@mock.patch("airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor.change_state")
def test_task_sdk(self, change_state_mock, mock_airflow_key, mock_executor, mock_cmd):
"""Test task sdk execution from end-to-end."""
from airflow.executors.workloads import ExecuteTask

workload = mock.Mock(spec=ExecuteTask)
workload.ti = mock.Mock(spec=TaskInstance)
workload.ti.key = mock_airflow_key()
tags_exec_config = [{"key": "FOO", "value": "BAR"}]
workload.ti.executor_config = {"tags": tags_exec_config}
ser_workload = json.dumps({"test_key": "test_value"})
workload.model_dump_json.return_value = ser_workload

mock_executor.queue_workload(workload, mock.Mock())

mock_executor.ecs.run_task.return_value = {
"tasks": [
{
"taskArn": ARN1,
"lastStatus": "",
"desiredStatus": "",
"containers": [{"name": "some-ecs-container"}],
}
],
"failures": [],
}

assert mock_executor.queued_tasks[workload.ti.key] == workload
assert len(mock_executor.pending_tasks) == 0
assert len(mock_executor.running) == 0
mock_executor._process_workloads([workload])
assert len(mock_executor.queued_tasks) == 0
assert len(mock_executor.running) == 1
assert workload.ti.key in mock_executor.running
assert len(mock_executor.pending_tasks) == 1
assert mock_executor.pending_tasks[0].command == [
"python",
"-m",
"airflow.sdk.execution_time.execute_workload",
"--json-string",
'{"test_key": "test_value"}',
]

mock_executor.attempt_task_runs()
mock_executor.ecs.run_task.assert_called_once()
assert len(mock_executor.pending_tasks) == 0
mock_executor.ecs.run_task.assert_called_once_with(
cluster="some-cluster",
count=1,
launchType="FARGATE",
platformVersion="LATEST",
taskDefinition="some-task-def",
tags=tags_exec_config,
networkConfiguration={
"awsvpcConfiguration": {
"assignPublicIp": "DISABLED",
"securityGroups": ["sg1", "sg2"],
"subnets": ["sub1", "sub2"],
},
},
overrides={
"containerOverrides": [
{
"command": [
"python",
"-m",
"airflow.sdk.execution_time.execute_workload",
"--json-string",
ser_workload,
],
"environment": [
{
"name": "AIRFLOW_IS_EXECUTOR_CONTAINER",
"value": "true",
},
],
"name": "container-name",
},
],
},
)

# Task is stored in active worker.
assert len(mock_executor.active_workers) == 1
assert ARN1 in mock_executor.active_workers.task_by_key(workload.ti.key).task_arn
change_state_mock.assert_called_once_with(
workload.ti.key, TaskInstanceState.RUNNING, ARN1, remove_running=False
)

@mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0))
def test_success_execute_api_exception(self, mock_backoff, mock_executor, mock_cmd):
"""Test what happens when ECS throws an exception, but ultimately runs the task."""
@@ -402,6 +402,7 @@ def run_next(self, next_job: KubernetesJobType) -> None:
"python",
"-m",
"airflow.sdk.execution_time.execute_workload",
"--json-path",
"/tmp/execute/input.json",
]
else:
@@ -196,6 +196,7 @@ def test_pod_spec_for_task_sdk_runs(self, content_json, expected, data_file):
"python",
"-m",
"airflow.sdk.execution_time.execute_workload",
"--json-path",
"/tmp/execute/input.json",
],
pod_override_object=None,
@@ -226,7 +227,12 @@ def test_pod_spec_for_task_sdk_runs(self, content_json, expected, data_file):
assert volume == {"emptyDir": {}, "name": "execute-volume"}

main_container = sanitized_result["spec"]["containers"][0]
assert main_container["command"] == ["python", "-m", "airflow.sdk.execution_time.execute_workload"]
assert main_container["command"] == [
"python",
"-m",
"airflow.sdk.execution_time.execute_workload",
"--json-path",
]
assert main_container["args"] == ["/tmp/execute/input.json"]

def test_from_obj_pod_override_object(self):
55 changes: 42 additions & 13 deletions task-sdk/src/airflow/sdk/execution_time/execute_workload.py
@@ -29,15 +29,17 @@

import argparse
import sys
from typing import TYPE_CHECKING

import structlog

log = structlog.get_logger(logger_name=__name__)
if TYPE_CHECKING:
from airflow.executors.workloads import ExecuteTask

log = structlog.get_logger(logger_name=__name__)

def execute_workload(input: str) -> None:
from pydantic import TypeAdapter

def execute_workload(workload: ExecuteTask) -> None:
from airflow.configuration import conf
from airflow.executors import workloads
from airflow.sdk.execution_time.supervisor import supervise
@@ -48,21 +50,20 @@ def execute_workload(input: str) -> None:

configure_logging(output=sys.stdout.buffer, enable_pretty_log=False)

decoder = TypeAdapter[workloads.All](workloads.All)
workload = decoder.validate_json(input)

if not isinstance(workload, workloads.ExecuteTask):
raise ValueError(f"We do not know how to handle {type(workload)}")
raise ValueError(f"Executor does not know how to handle {type(workload)}")

log.info("Executing workload", workload=workload)
server = conf.get("core", "execution_api_server_url")
log.info("Connecting to server:", server=server)

supervise(
# This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this.
ti=workload.ti, # type: ignore[arg-type]
dag_rel_path=workload.dag_rel_path,
bundle_info=workload.bundle_info,
token=workload.token,
server=conf.get("core", "execution_api_server_url"),
server=server,
log_path=workload.log_path,
# Include the output of the task to stdout too, so that in process logs can be read from via the
# kubeapi as pod logs.
@@ -74,16 +75,44 @@ def main():
parser = argparse.ArgumentParser(
description="Execute a workload in a Containerised executor using the task SDK."
)
parser.add_argument(
"input_file", help="Path to the input JSON file containing the execution workload payload."

# Create a mutually exclusive group to ensure that only one of the flags is set
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
"--json-path",
help="Path to the input JSON file containing the execution workload payload.",
type=str,
)
group.add_argument(
"--json-string",
help="The JSON string itself containing the execution workload payload.",
type=str,
)

args = parser.parse_args()

with open(args.input_file) as file:
input_data = file.read()
from pydantic import TypeAdapter

execute_workload(input_data)
from airflow.executors import workloads

decoder = TypeAdapter[workloads.All](workloads.All)
if args.json_path:
try:
with open(args.json_path) as file:
input_data = file.read()
workload = decoder.validate_json(input_data)
except Exception as e:
log.error("Failed to read file", error=str(e))
sys.exit(1)

elif args.json_string:
try:
workload = decoder.validate_json(args.json_string)
except Exception as e:
log.error("Failed to parse input JSON string", error=str(e))
sys.exit(1)

execute_workload(workload)


if __name__ == "__main__":
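With the mutually exclusive flags in place, an illustrative invocation for each mode (paths and the payload variable are placeholders, not values from this PR):

# File-based payload, as built by the Kubernetes executor above:
python -m airflow.sdk.execution_time.execute_workload --json-path /tmp/execute/input.json

# Inline payload, as built by the ECS executor above ($WORKLOAD_JSON holds a serialized ExecuteTask):
python -m airflow.sdk.execution_time.execute_workload --json-string "$WORKLOAD_JSON"

# Passing both flags, or neither, is rejected by the required mutually exclusive argparse group.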