diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
index 3d9821ae25703..04e88f430df19 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
@@ -49,12 +49,16 @@
     exponential_backoff_retry,
 )
 from airflow.providers.amazon.aws.hooks.ecs import EcsHook
+from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
 from airflow.stats import Stats
 from airflow.utils import timezone
 from airflow.utils.helpers import merge_dicts
 from airflow.utils.state import State
 
 if TYPE_CHECKING:
+    from sqlalchemy.orm import Session
+
+    from airflow.executors import workloads
     from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
     from airflow.providers.amazon.aws.executors.ecs.utils import (
         CommandType,
@@ -100,6 +104,11 @@ class AwsEcsExecutor(BaseExecutor):
     # AWS limits the maximum number of ARNs in the describe_tasks function.
     DESCRIBE_TASKS_BATCH_SIZE = 99
 
+    if TYPE_CHECKING and AIRFLOW_V_3_0_PLUS:
+        # In the v3 path, we store workloads, not commands as strings.
+        # TODO: TaskSDK: move this type change into BaseExecutor
+        queued_tasks: dict[TaskInstanceKey, workloads.All]  # type: ignore[assignment]
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.active_workers: EcsTaskCollection = EcsTaskCollection()
@@ -114,6 +123,31 @@ def __init__(self, *args, **kwargs):
 
         self.run_task_kwargs = self._load_run_kwargs()
 
+    def queue_workload(self, workload: workloads.All, session: Session | None) -> None:
+        from airflow.executors import workloads
+
+        if not isinstance(workload, workloads.ExecuteTask):
+            raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}")
+        ti = workload.ti
+        self.queued_tasks[ti.key] = workload
+
+    def _process_workloads(self, workloads: list[workloads.All]) -> None:
+        from airflow.executors.workloads import ExecuteTask
+
+        # Airflow V3 version
+        for w in workloads:
+            if not isinstance(w, ExecuteTask):
+                raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(w)}")
+
+            command = [w]
+            key = w.ti.key
+            queue = w.ti.queue
+            executor_config = w.ti.executor_config or {}
+
+            del self.queued_tasks[key]
+            self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config)  # type: ignore[arg-type]
+            self.running.add(key)
+
     def start(self):
         """Call this when the Executor is run for the first time by the scheduler."""
         check_health = conf.getboolean(
@@ -462,6 +496,24 @@ def execute_async(self, key: TaskInstanceKey, command: CommandType, queue=None,
         """Save the task to be executed in the next sync by inserting the commands into a queue."""
         if executor_config and ("name" in executor_config or "command" in executor_config):
             raise ValueError('Executor Config should never override "name" or "command"')
+        if len(command) == 1:
+            from airflow.executors.workloads import ExecuteTask
+
+            if isinstance(command[0], ExecuteTask):
+                workload = command[0]
+                ser_input = workload.model_dump_json()
+                command = [
+                    "python",
+                    "-m",
+                    "airflow.sdk.execution_time.execute_workload",
+                    "--json-string",
+                    ser_input,
+                ]
+            else:
+                raise ValueError(
+                    f"EcsExecutor doesn't know how to handle workload of type: {type(command[0])}"
+                )
+
         self.pending_tasks.append(
             EcsQueuedTask(key, command, queue, executor_config or {}, 1, timezone.utcnow())
         )
diff --git a/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py b/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py
index eae0b5346892c..074d2babd8524 100644
--- a/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py
+++ b/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py
@@ -59,7 +59,7 @@
 
 from tests_common import RUNNING_TESTS_AGAINST_AIRFLOW_PACKAGES
 from tests_common.test_utils.config import conf_vars
-from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS
+from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS
 
 pytestmark = pytest.mark.db_test
 
@@ -412,6 +412,96 @@ def test_execute(self, change_state_mock, mock_airflow_key, mock_executor, mock_
             airflow_key, TaskInstanceState.RUNNING, ARN1, remove_running=False
         )
 
+    @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3+")
+    @mock.patch("airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor.change_state")
+    def test_task_sdk(self, change_state_mock, mock_airflow_key, mock_executor, mock_cmd):
+        """Test task sdk execution from end-to-end."""
+        from airflow.executors.workloads import ExecuteTask
+
+        workload = mock.Mock(spec=ExecuteTask)
+        workload.ti = mock.Mock(spec=TaskInstance)
+        workload.ti.key = mock_airflow_key()
+        tags_exec_config = [{"key": "FOO", "value": "BAR"}]
+        workload.ti.executor_config = {"tags": tags_exec_config}
+        ser_workload = json.dumps({"test_key": "test_value"})
+        workload.model_dump_json.return_value = ser_workload
+
+        mock_executor.queue_workload(workload, mock.Mock())
+
+        mock_executor.ecs.run_task.return_value = {
+            "tasks": [
+                {
+                    "taskArn": ARN1,
+                    "lastStatus": "",
+                    "desiredStatus": "",
+                    "containers": [{"name": "some-ecs-container"}],
+                }
+            ],
+            "failures": [],
+        }
+
+        assert mock_executor.queued_tasks[workload.ti.key] == workload
+        assert len(mock_executor.pending_tasks) == 0
+        assert len(mock_executor.running) == 0
+        mock_executor._process_workloads([workload])
+        assert len(mock_executor.queued_tasks) == 0
+        assert len(mock_executor.running) == 1
+        assert workload.ti.key in mock_executor.running
+        assert len(mock_executor.pending_tasks) == 1
+        assert mock_executor.pending_tasks[0].command == [
+            "python",
+            "-m",
+            "airflow.sdk.execution_time.execute_workload",
+            "--json-string",
+            '{"test_key": "test_value"}',
+        ]
+
+        mock_executor.attempt_task_runs()
+        mock_executor.ecs.run_task.assert_called_once()
+        assert len(mock_executor.pending_tasks) == 0
+        mock_executor.ecs.run_task.assert_called_once_with(
+            cluster="some-cluster",
+            count=1,
+            launchType="FARGATE",
+            platformVersion="LATEST",
+            taskDefinition="some-task-def",
+            tags=tags_exec_config,
+            networkConfiguration={
+                "awsvpcConfiguration": {
+                    "assignPublicIp": "DISABLED",
+                    "securityGroups": ["sg1", "sg2"],
+                    "subnets": ["sub1", "sub2"],
+                },
+            },
+            overrides={
+                "containerOverrides": [
+                    {
+                        "command": [
+                            "python",
+                            "-m",
+                            "airflow.sdk.execution_time.execute_workload",
+                            "--json-string",
+                            ser_workload,
+                        ],
+                        "environment": [
+                            {
+                                "name": "AIRFLOW_IS_EXECUTOR_CONTAINER",
+                                "value": "true",
+                            },
+                        ],
+                        "name": "container-name",
+                    },
+                ],
+            },
+        )
+
+        # Task is stored in active worker.
+        assert len(mock_executor.active_workers) == 1
+        assert ARN1 in mock_executor.active_workers.task_by_key(workload.ti.key).task_arn
+        change_state_mock.assert_called_once_with(
+            workload.ti.key, TaskInstanceState.RUNNING, ARN1, remove_running=False
+        )
+
     @mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0))
     def test_success_execute_api_exception(self, mock_backoff, mock_executor, mock_cmd):
         """Test what happens when ECS throws an exception, but ultimately runs the task."""
diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py
index a59f1e47d1ba7..ceef834428928 100644
--- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py
+++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py
@@ -402,6 +402,7 @@ def run_next(self, next_job: KubernetesJobType) -> None:
                     "python",
                     "-m",
                     "airflow.sdk.execution_time.execute_workload",
+                    "--json-path",
                     "/tmp/execute/input.json",
                 ]
             else:
diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/test_pod_generator.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/test_pod_generator.py
index e4c9db066882f..efb826b3ea6de 100644
--- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/test_pod_generator.py
+++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/test_pod_generator.py
@@ -196,6 +196,7 @@ def test_pod_spec_for_task_sdk_runs(self, content_json, expected, data_file):
                 "python",
                 "-m",
                 "airflow.sdk.execution_time.execute_workload",
+                "--json-path",
                 "/tmp/execute/input.json",
             ],
             pod_override_object=None,
@@ -226,7 +227,12 @@ def test_pod_spec_for_task_sdk_runs(self, content_json, expected, data_file):
         assert volume == {"emptyDir": {}, "name": "execute-volume"}
 
         main_container = sanitized_result["spec"]["containers"][0]
-        assert main_container["command"] == ["python", "-m", "airflow.sdk.execution_time.execute_workload"]
+        assert main_container["command"] == [
+            "python",
+            "-m",
+            "airflow.sdk.execution_time.execute_workload",
+            "--json-path",
+        ]
         assert main_container["args"] == ["/tmp/execute/input.json"]
 
     def test_from_obj_pod_override_object(self):
diff --git a/task-sdk/src/airflow/sdk/execution_time/execute_workload.py b/task-sdk/src/airflow/sdk/execution_time/execute_workload.py
index 5fd9d6669b763..8a6a9d0963a60 100644
--- a/task-sdk/src/airflow/sdk/execution_time/execute_workload.py
+++ b/task-sdk/src/airflow/sdk/execution_time/execute_workload.py
@@ -29,15 +29,17 @@
 
 import argparse
 import sys
+from typing import TYPE_CHECKING
 
 import structlog
 
-log = structlog.get_logger(logger_name=__name__)
+if TYPE_CHECKING:
+    from airflow.executors.workloads import ExecuteTask
 
+log = structlog.get_logger(logger_name=__name__)
 
-def execute_workload(input: str) -> None:
-    from pydantic import TypeAdapter
 
+def execute_workload(workload: ExecuteTask) -> None:
     from airflow.configuration import conf
     from airflow.executors import workloads
     from airflow.sdk.execution_time.supervisor import supervise
@@ -48,13 +50,12 @@ def execute_workload(input: str) -> None:
 
     configure_logging(output=sys.stdout.buffer, enable_pretty_log=False)
 
-    decoder = TypeAdapter[workloads.All](workloads.All)
-    workload = decoder.validate_json(input)
-
     if not isinstance(workload, workloads.ExecuteTask):
-        raise ValueError(f"We do not know how to handle {type(workload)}")
{type(workload)}") + raise ValueError(f"Executor does not know how to handle {type(workload)}") log.info("Executing workload", workload=workload) + server = conf.get("core", "execution_api_server_url") + log.info("Connecting to server:", server=server) supervise( # This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this. @@ -62,7 +63,7 @@ def execute_workload(input: str) -> None: dag_rel_path=workload.dag_rel_path, bundle_info=workload.bundle_info, token=workload.token, - server=conf.get("core", "execution_api_server_url"), + server=server, log_path=workload.log_path, # Include the output of the task to stdout too, so that in process logs can be read from via the # kubeapi as pod logs. @@ -74,16 +75,44 @@ def main(): parser = argparse.ArgumentParser( description="Execute a workload in a Containerised executor using the task SDK." ) - parser.add_argument( - "input_file", help="Path to the input JSON file containing the execution workload payload." + + # Create a mutually exclusive group to ensure that only one of the flags is set + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument( + "--json-path", + help="Path to the input JSON file containing the execution workload payload.", + type=str, + ) + group.add_argument( + "--json-string", + help="The JSON string itself containing the execution workload payload.", + type=str, ) args = parser.parse_args() - with open(args.input_file) as file: - input_data = file.read() + from pydantic import TypeAdapter - execute_workload(input_data) + from airflow.executors import workloads + + decoder = TypeAdapter[workloads.All](workloads.All) + if args.json_path: + try: + with open(args.json_path) as file: + input_data = file.read() + workload = decoder.validate_json(input_data) + except Exception as e: + log.error("Failed to read file", error=str(e)) + sys.exit(1) + + elif args.json_string: + try: + workload = decoder.validate_json(args.json_string) + except Exception as e: + log.error("Failed to parse input JSON string", error=str(e)) + sys.exit(1) + + execute_workload(workload) if __name__ == "__main__":