diff --git a/airflow-core/src/airflow/executors/base_executor.py b/airflow-core/src/airflow/executors/base_executor.py index 308c9b1d01797..a59231989ecbf 100644 --- a/airflow-core/src/airflow/executors/base_executor.py +++ b/airflow-core/src/airflow/executors/base_executor.py @@ -29,12 +29,13 @@ from airflow.cli.cli_config import DefaultHelpParser from airflow.configuration import conf +from airflow.executors import workloads from airflow.executors.executor_loader import ExecutorLoader from airflow.models import Log from airflow.stats import Stats from airflow.traces import NO_TRACE_ID from airflow.traces.tracer import Trace, add_span, gen_context -from airflow.traces.utils import gen_span_id_from_ti_key, gen_trace_id +from airflow.traces.utils import gen_span_id_from_ti_key from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.state import TaskInstanceState from airflow.utils.thread_safe_dict import ThreadSafeDict @@ -51,28 +52,14 @@ from airflow.callbacks.base_callback_sink import BaseCallbackSink from airflow.callbacks.callback_requests import CallbackRequest from airflow.cli.cli_config import GroupCommand - from airflow.executors import workloads from airflow.executors.executor_utils import ExecutorName from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancekey import TaskInstanceKey - # Command to execute - list of strings - # the first element is always "airflow". - # It should be result of TaskInstance.generate_command method. - CommandType = Sequence[str] - - # Task that is queued. It contains all the information that is - # needed to run the task. - # - # Tuple of: command, priority, queue name, TaskInstance - QueuedTaskInstanceType = tuple[CommandType, int, Optional[str], TaskInstance] - # Event_buffer dict value type # Tuple of: state, info EventBufferValueType = tuple[Optional[str], Any] - # Task tuple to send to be executed - TaskTuple = tuple[TaskInstanceKey, CommandType, Optional[str], Optional[Any]] log = logging.getLogger(__name__) @@ -159,7 +146,7 @@ def __init__(self, parallelism: int = PARALLELISM, team_id: str | None = None): self.parallelism: int = parallelism self.team_id: str | None = team_id - self.queued_tasks: dict[TaskInstanceKey, QueuedTaskInstanceType] = {} + self.queued_tasks: dict[TaskInstanceKey, workloads.ExecuteTask] = {} self.running: set[TaskInstanceKey] = set() self.event_buffer: dict[TaskInstanceKey, EventBufferValueType] = {} self._task_event_logs: deque[Log] = deque() @@ -192,62 +179,23 @@ def log_task_event(self, *, event: str, extra: str, ti_key: TaskInstanceKey): """Add an event to the log table.""" self._task_event_logs.append(Log(event=event, task_instance=ti_key, extra=extra)) - def queue_command( - self, - task_instance: TaskInstance, - command: CommandType, - priority: int = 1, - queue: str | None = None, - ): - """Queues command to task.""" - if task_instance.key not in self.queued_tasks: - self.log.info("Adding to queue: %s", command) - self.queued_tasks[task_instance.key] = (command, priority, queue, task_instance) - else: - self.log.error("could not queue task %s", task_instance.key) - def queue_workload(self, workload: workloads.All, session: Session) -> None: - raise ValueError(f"Un-handled workload kind {type(workload).__name__!r} in {type(self).__name__}") - - def queue_task_instance( - self, - task_instance: TaskInstance, - mark_success: bool = False, - ignore_all_deps: bool = False, - ignore_depends_on_past: bool = False, - wait_for_past_depends_before_skipping: bool = False, - 
ignore_task_deps: bool = False, - ignore_ti_state: bool = False, - pool: str | None = None, - cfg_path: str | None = None, - ) -> None: - """Queues task instance.""" - if TYPE_CHECKING: - assert task_instance.task - - pool = pool or task_instance.pool - - command_list_to_run = task_instance.command_as_list( - local=True, - mark_success=mark_success, - ignore_all_deps=ignore_all_deps, - ignore_depends_on_past=ignore_depends_on_past, - wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping, - ignore_task_deps=ignore_task_deps, - ignore_ti_state=ignore_ti_state, - pool=pool, - # cfg_path is needed to propagate the config values if using impersonation - # (run_as_user), given that there are different code paths running tasks. - # https://github.com/apache/airflow/pull/2991 - cfg_path=cfg_path, - ) - self.log.debug("created command %s", command_list_to_run) - self.queue_command( - task_instance, - command_list_to_run, - priority=task_instance.priority_weight, - queue=task_instance.task.queue, - ) + if not isinstance(workload, workloads.ExecuteTask): + raise ValueError(f"Un-handled workload kind {type(workload).__name__!r} in {type(self).__name__}") + ti = workload.ti + self.queued_tasks[ti.key] = workload + + def _process_workloads(self, workloads: Sequence[workloads.All]) -> None: + """ + Process the given workloads. + + This method must be implemented by subclasses to define how they handle + the execution of workloads (e.g., queuing them to workers, submitting to + external systems, etc.). + + :param workloads: List of workloads to process + """ + raise NotImplementedError(f"{type(self).__name__} must implement _process_workloads()") def has_task(self, task_instance: TaskInstance) -> bool: """ @@ -347,29 +295,19 @@ def _emit_metrics(self, open_slots, num_running_tasks, num_queued_tasks): tags={"status": "running", "name": name}, ) - def order_queued_tasks_by_priority(self) -> list[tuple[TaskInstanceKey, QueuedTaskInstanceType]]: + def order_queued_tasks_by_priority(self) -> list[tuple[TaskInstanceKey, workloads.ExecuteTask]]: """ Orders the queued tasks by priority. - :return: List of tuples from the queued_tasks according to the priority. + :return: List of workloads from the queued_tasks according to the priority. """ - from airflow.executors import workloads - if not self.queued_tasks: return [] - kind = next(iter(self.queued_tasks.values())) - if isinstance(kind, workloads.BaseWorkload): - # V3 + new executor that supports workloads - return sorted( - self.queued_tasks.items(), - key=lambda x: x[1].ti.priority_weight, - reverse=True, - ) - + # V3 + new executor that supports workloads return sorted( self.queued_tasks.items(), - key=lambda x: x[1][1], + key=lambda x: x[1].ti.priority_weight, reverse=True, ) @@ -381,7 +319,6 @@ def trigger_tasks(self, open_slots: int) -> None: :param open_slots: Number of open slots """ sorted_queue = self.order_queued_tasks_by_priority() - task_tuples = [] workload_list = [] for _ in range(min((open_slots, len(self.queued_tasks)))): @@ -397,103 +334,38 @@ def trigger_tasks(self, open_slots: int) -> None: # deferred task has completed. In this case and for this reason, # we make a small number of attempts to see if the task has been # removed from the running set in the meantime. 
- if key in self.running: - attempt = self.attempts[key] - if attempt.can_try_again(): - # if it hasn't been much time since first check, let it be checked again next time - self.log.info("queued but still running; attempt=%s task=%s", attempt.total_tries, key) - continue - - # Otherwise, we give up and remove the task from the queue. - self.log.error( - "could not queue task %s (still running after %d attempts).", - key, - attempt.total_tries, - ) - self.log_task_event( - event="task launch failure", - extra=( - "Task was in running set and could not be queued " - f"after {attempt.total_tries} attempts." - ), - ti_key=key, - ) + if key in self.attempts: del self.attempts[key] - del self.queued_tasks[key] - else: - if key in self.attempts: - del self.attempts[key] - # TODO: TaskSDK: Compat, remove when KubeExecutor is fully moved over to TaskSDK too. - # TODO: TaskSDK: We need to minimum version requirements on executors with Airflow 3. - # How/where do we do that? Executor loader? - from airflow.executors import workloads - - if isinstance(item, workloads.ExecuteTask) and hasattr(item, "ti"): - ti = item.ti - - # If it's None, then the span for the current TaskInstanceKey hasn't been started. - if self.active_spans is not None and self.active_spans.get(key) is None: - from airflow.models.taskinstance import SimpleTaskInstance - - if isinstance(ti, (SimpleTaskInstance, workloads.TaskInstance)): - parent_context = Trace.extract(ti.parent_context_carrier) - else: - parent_context = Trace.extract(ti.dag_run.context_carrier) - # Start a new span using the context from the parent. - # Attributes will be set once the task has finished so that all - # values will be available (end_time, duration, etc.). - - span = Trace.start_child_span( - span_name=f"{ti.task_id}", - parent_context=parent_context, - component="task", - start_as_current=False, - ) - self.active_spans.set(key, span) - # Inject the current context into the carrier. - carrier = Trace.inject() - ti.context_carrier = carrier - - if hasattr(self, "_process_workloads"): - workload_list.append(item) - else: - (command, _, queue, ti) = item - task_tuples.append((key, command, queue, getattr(ti, "executor_config", None))) - - if task_tuples: - self._process_tasks(task_tuples) - elif workload_list: - self._process_workloads(workload_list) # type: ignore[attr-defined] - @add_span - def _process_tasks(self, task_tuples: list[TaskTuple]) -> None: - for key, command, queue, executor_config in task_tuples: - task_instance = self.queued_tasks[key][3] # TaskInstance in fourth element - trace_id = int(gen_trace_id(task_instance.dag_run, as_int=True)) - span_id = int(gen_span_id_from_ti_key(key, as_int=True)) - links = [{"trace_id": trace_id, "span_id": span_id}] + if isinstance(item, workloads.ExecuteTask) and hasattr(item, "ti"): + ti = item.ti - # assuming that the span_id will very likely be unique inside the trace - with Trace.start_span( - span_name=f"{key.dag_id}.{key.task_id}", - component="BaseExecutor", - span_id=span_id, - links=links, - ) as span: - span.set_attributes( - { - "dag_id": key.dag_id, - "run_id": key.run_id, - "task_id": key.task_id, - "try_number": key.try_number, - "command": str(command), - "queue": str(queue), - "executor_config": str(executor_config), - } - ) - del self.queued_tasks[key] - self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config) - self.running.add(key) + # If it's None, then the span for the current TaskInstanceKey hasn't been started. 
+ if self.active_spans is not None and self.active_spans.get(key) is None: + from airflow.models.taskinstance import SimpleTaskInstance + + if isinstance(ti, (SimpleTaskInstance, workloads.TaskInstance)): + parent_context = Trace.extract(ti.parent_context_carrier) + else: + parent_context = Trace.extract(ti.dag_run.context_carrier) + # Start a new span using the context from the parent. + # Attributes will be set once the task has finished so that all + # values will be available (end_time, duration, etc.). + + span = Trace.start_child_span( + span_name=f"{ti.task_id}", + parent_context=parent_context, + component="task", + start_as_current=False, + ) + self.active_spans.set(key, span) + # Inject the current context into the carrier. + carrier = Trace.inject() + ti.context_carrier = carrier + + workload_list.append(item) + if workload_list: + self._process_workloads(workload_list) # TODO: This should not be using `TaskInstanceState` here, this is just "did the process complete, or did # it die". It is possible for the task itself to finish with success, but the state of the task to be set @@ -609,23 +481,6 @@ def get_event_buffer(self, dag_ids=None) -> dict[TaskInstanceKey, EventBufferVal return cleared_events - def execute_async( - self, - key: TaskInstanceKey, - command: CommandType, - queue: str | None = None, - executor_config: Any | None = None, - ) -> None: # pragma: no cover - """ - Execute the command asynchronously. - - :param key: Unique key for the task instance - :param command: Command to run - :param queue: name of the queue - :param executor_config: Configuration passed to the executor. - """ - raise NotImplementedError() - def get_task_log(self, ti: TaskInstance, try_number: int) -> tuple[list[str], list[str]]: """ Return the task logs. @@ -683,28 +538,6 @@ def slots_occupied(self): """Number of tasks this executor instance is currently managing.""" return len(self.running) + len(self.queued_tasks) - @staticmethod - def validate_airflow_tasks_run_command(command: Sequence[str]) -> tuple[str | None, str | None]: - """ - Check if the command to execute is airflow command. 
- - Returns tuple (dag_id,task_id) retrieved from the command (replaced with None values if missing) - """ - if command[0:3] != ["airflow", "tasks", "run"]: - raise ValueError('The command must start with ["airflow", "tasks", "run"].') - if len(command) > 3 and "--help" not in command: - dag_id: str | None = None - task_id: str | None = None - for arg in command[3:]: - if not arg.startswith("--"): - if dag_id is None: - dag_id = arg - else: - task_id = arg - break - return dag_id, task_id - return None, None - def debug_dump(self): """Get called in response to SIGUSR2 by the scheduler.""" self.log.info( diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index 586a2053b83eb..e5857aecce90b 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -42,7 +42,6 @@ from airflow.configuration import conf from airflow.dag_processing.bundles.base import BundleUsageTrackingManager from airflow.executors import workloads -from airflow.executors.base_executor import BaseExecutor from airflow.executors.executor_loader import ExecutorLoader from airflow.jobs.base_job_runner import BaseJobRunner from airflow.jobs.job import Job, perform_heartbeat @@ -88,6 +87,7 @@ from pendulum.datetime import DateTime from sqlalchemy.orm import Query, Session + from airflow.executors.base_executor import BaseExecutor from airflow.executors.executor_utils import ExecutorName from airflow.models.taskinstance import TaskInstanceKey from airflow.utils.sqlalchemy import ( @@ -698,29 +698,8 @@ def _enqueue_task_instances_with_queued_state( ti.set_state(None, session=session) continue - # TODO: Task-SDK: This check is transitionary. Remove once all executors are ported over. 
- # Has a real queue_activity implemented - if executor.queue_workload.__func__ is not BaseExecutor.queue_workload: # type: ignore[attr-defined] - workload = workloads.ExecuteTask.make(ti, generator=executor.jwt_generator) - executor.queue_workload(workload, session=session) - continue - - command = ti.command_as_list( - local=True, - ) - - priority = ti.priority_weight - queue = ti.queue - self.log.info( - "Sending %s to %s with priority %s and queue %s", ti.key, executor.name, priority, queue - ) - - executor.queue_command( - ti, - command, - priority=priority, - queue=queue, - ) + workload = workloads.ExecuteTask.make(ti, generator=executor.jwt_generator) + executor.queue_workload(workload, session=session) def _critical_section_enqueue_task_instances(self, session: Session) -> int: """ diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index bc16db4092908..b49f5e5f8b004 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -28,7 +28,6 @@ from collections.abc import Collection, Generator, Iterable, Sequence from datetime import timedelta from functools import cache -from pathlib import Path from typing import TYPE_CHECKING, Any from urllib.parse import quote @@ -110,7 +109,6 @@ if TYPE_CHECKING: from datetime import datetime - from pathlib import PurePath import pendulum from sqlalchemy.engine import Connection as SAConnection, Engine @@ -763,158 +761,6 @@ def to_runtime_ti(self, context_from_server) -> RuntimeTaskInstanceProtocol: return runtime_ti - @staticmethod - def _command_as_list( - ti: TaskInstance, - mark_success: bool = False, - ignore_all_deps: bool = False, - ignore_task_deps: bool = False, - ignore_depends_on_past: bool = False, - wait_for_past_depends_before_skipping: bool = False, - ignore_ti_state: bool = False, - local: bool = False, - raw: bool = False, - pool: str | None = None, - cfg_path: str | None = None, - ) -> list[str]: - dag: DAG | DagModel | None - # Use the dag if we have it, else fallback to the ORM dag_model, which might not be loaded - if hasattr(ti, "task") and getattr(ti.task, "dag", None) is not None: - if TYPE_CHECKING: - assert ti.task - assert isinstance(ti.task.dag, SchedulerDAG) - dag = ti.task.dag - else: - dag = ti.dag_model - - if dag is None: - raise ValueError("DagModel is empty") - - path = None - if dag.relative_fileloc: - path = Path(dag.relative_fileloc) - - if path: - if not path.is_absolute(): - path = "DAGS_FOLDER" / path - - return TaskInstance.generate_command( - ti.dag_id, - ti.task_id, - run_id=ti.run_id, - mark_success=mark_success, - ignore_all_deps=ignore_all_deps, - ignore_task_deps=ignore_task_deps, - ignore_depends_on_past=ignore_depends_on_past, - wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping, - ignore_ti_state=ignore_ti_state, - local=local, - file_path=path, - raw=raw, - pool=pool, - cfg_path=cfg_path, - map_index=ti.map_index, - ) - - def command_as_list( - self, - mark_success: bool = False, - ignore_all_deps: bool = False, - ignore_task_deps: bool = False, - ignore_depends_on_past: bool = False, - wait_for_past_depends_before_skipping: bool = False, - ignore_ti_state: bool = False, - local: bool = False, - raw: bool = False, - pool: str | None = None, - cfg_path: str | None = None, - ) -> list[str]: - """ - Return a command that can be executed anywhere where airflow is installed. - - This command is part of the message sent to executors by the orchestrator. 
- """ - return TaskInstance._command_as_list( - ti=self, - mark_success=mark_success, - ignore_all_deps=ignore_all_deps, - ignore_task_deps=ignore_task_deps, - ignore_depends_on_past=ignore_depends_on_past, - wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping, - ignore_ti_state=ignore_ti_state, - local=local, - raw=raw, - pool=pool, - cfg_path=cfg_path, - ) - - @staticmethod - def generate_command( - dag_id: str, - task_id: str, - run_id: str, - mark_success: bool = False, - ignore_all_deps: bool = False, - ignore_depends_on_past: bool = False, - wait_for_past_depends_before_skipping: bool = False, - ignore_task_deps: bool = False, - ignore_ti_state: bool = False, - local: bool = False, - file_path: PurePath | str | None = None, - raw: bool = False, - pool: str | None = None, - cfg_path: str | None = None, - map_index: int = -1, - ) -> list[str]: - """ - Generate the shell command required to execute this task instance. - - :param dag_id: DAG ID - :param task_id: Task ID - :param run_id: The run_id of this task's DagRun - :param mark_success: Whether to mark the task as successful - :param ignore_all_deps: Ignore all ignorable dependencies. - Overrides the other ignore_* parameters. - :param ignore_depends_on_past: Ignore depends_on_past parameter of DAGs - (e.g. for Backfills) - :param wait_for_past_depends_before_skipping: Wait for past depends before marking the ti as skipped - :param ignore_task_deps: Ignore task-specific dependencies such as depends_on_past - and trigger rule - :param ignore_ti_state: Ignore the task instance's previous failure/success - :param local: Whether to run the task locally - :param file_path: path to the file containing the DAG definition - :param raw: raw mode (needs more details) - :param pool: the Airflow pool that the task should run in - :param cfg_path: the Path to the configuration file - :return: shell command that can be used to run the task instance - """ - cmd = ["airflow", "tasks", "run", dag_id, task_id, run_id] - if mark_success: - cmd.extend(["--mark-success"]) - if ignore_all_deps: - cmd.extend(["--ignore-all-dependencies"]) - if ignore_task_deps: - cmd.extend(["--ignore-dependencies"]) - if ignore_depends_on_past: - cmd.extend(["--depends-on-past", "ignore"]) - elif wait_for_past_depends_before_skipping: - cmd.extend(["--depends-on-past", "wait"]) - if ignore_ti_state: - cmd.extend(["--force"]) - if local: - cmd.extend(["--local"]) - if pool: - cmd.extend(["--pool", pool]) - if raw: - cmd.extend(["--raw"]) - if file_path: - cmd.extend(["--subdir", os.fspath(file_path)]) - if cfg_path: - cmd.extend(["--cfg-path", cfg_path]) - if map_index != -1: - cmd.extend(["--map-index", str(map_index)]) - return cmd - @property def log_url(self) -> str: """Log URL for TaskInstance.""" diff --git a/airflow-core/tests/unit/executors/test_base_executor.py b/airflow-core/tests/unit/executors/test_base_executor.py index be4a7a0d93a21..575201b2f4498 100644 --- a/airflow-core/tests/unit/executors/test_base_executor.py +++ b/airflow-core/tests/unit/executors/test_base_executor.py @@ -27,6 +27,7 @@ from airflow.cli.cli_config import DefaultHelpParser, GroupCommand from airflow.cli.cli_parser import AirflowHelpFormatter +from airflow.executors import workloads from airflow.executors.base_executor import BaseExecutor, RunningRetryAttemptType from airflow.executors.local_executor import LocalExecutor from airflow.models.baseoperator import BaseOperator @@ -183,149 +184,61 @@ def test_try_adopt_task_instances(dag_maker): assert 
BaseExecutor().try_adopt_task_instances(tis) == tis -def enqueue_tasks(executor, dagrun): - for task_instance in dagrun.task_instances: - executor.queue_command(task_instance, ["airflow"]) - - def setup_trigger_tasks(dag_maker, parallelism=None): dagrun = setup_dagrun(dag_maker) if parallelism: executor = BaseExecutor(parallelism=parallelism) else: executor = BaseExecutor() - executor.execute_async = mock.Mock() - enqueue_tasks(executor, dagrun) - return executor, dagrun + executor._process_workloads = mock.Mock() -@pytest.mark.db_test -@pytest.mark.parametrize("open_slots", [1, 2, 3]) -def test_trigger_queued_tasks(dag_maker, open_slots): - executor_parallelism = 10 - executor, dagrun = setup_trigger_tasks(dag_maker, executor_parallelism) - num_tasks = len(dagrun.task_instances) - - # All tasks are queued in setup method - assert executor.slots_occupied == num_tasks - assert executor.slots_available == executor_parallelism - num_tasks - assert len(executor.queued_tasks) == num_tasks - assert len(executor.running) == 0 - executor.trigger_tasks(open_slots) - assert executor.slots_available == executor_parallelism - num_tasks - assert executor.slots_occupied == num_tasks - assert len(executor.queued_tasks) == num_tasks - open_slots - # Only open_slots number of tasks are allowed through to running - assert len(executor.running) == open_slots - assert executor.execute_async.call_count == open_slots + for task_instance in dagrun.task_instances: + workload = workloads.ExecuteTask.make(task_instance) + executor.queued_tasks[task_instance.key] = workload + + return executor, dagrun @pytest.mark.db_test -@pytest.mark.parametrize( - "can_try_num, change_state_num, second_exec", - [ - (2, 3, False), - (3, 3, True), - (4, 3, True), - ], -) -@mock.patch("airflow.executors.base_executor.RunningRetryAttemptType.can_try_again") -def test_trigger_running_tasks(can_try_mock, dag_maker, can_try_num, change_state_num, second_exec): - can_try_mock.side_effect = [True for _ in range(can_try_num)] + [False] +def test_trigger_queued_tasks(dag_maker): + """Test that trigger_tasks() calls _process_workloads() when there are queued workloads.""" executor, dagrun = setup_trigger_tasks(dag_maker) - open_slots = 100 - executor.trigger_tasks(open_slots) - expected_calls = len(dagrun.task_instances) # initially `execute_async` called for each task - assert executor.execute_async.call_count == expected_calls - # All the tasks are now "running", so while we enqueue them again here, - # they won't be executed again until the executor has been notified of a state change. - ti = dagrun.task_instances[0] - assert ti.key in executor.running - assert ti.key not in executor.queued_tasks - executor.queue_command(ti, ["airflow"]) - - # this is the problem we're dealing with: ti.key both queued and running - assert ti.key in executor.queued_tasks - assert ti.key in executor.running - assert len(executor.attempts) == 0 - executor.trigger_tasks(open_slots) - - # first trigger call after queueing again creates an attempt object - assert len(executor.attempts) == 1 - assert ti.key in executor.attempts - - for attempt in range(2, change_state_num + 2): - executor.trigger_tasks(open_slots) - if attempt <= min(can_try_num, change_state_num): - assert ti.key in executor.queued_tasks - assert ti.key in executor.running - # On the configured attempt, we notify the executor that the task has succeeded. 
- if attempt == change_state_num: - executor.change_state(ti.key, State.SUCCESS) - assert ti.key not in executor.running - # retry was ok when state changed, ti.key will be in running (for the second time) - if can_try_num >= change_state_num: - assert ti.key in executor.running - else: # otherwise, it won't be - assert ti.key not in executor.running - # either way, ti.key not in queued -- it was either removed because never left running - # or it was moved out when run 2nd time - assert ti.key not in executor.queued_tasks - assert not executor.attempts - - # we expect one more "execute_async" if TI was marked successful - # this would move it out of running set and free the queued TI to be executed again - if second_exec is True: - expected_calls += 1 - - assert executor.execute_async.call_count == expected_calls + # Verify tasks are queued + assert len(executor.queued_tasks) == 3 + # Call trigger_tasks with enough slots + executor.trigger_tasks(open_slots=10) -@pytest.mark.db_test -def test_validate_airflow_tasks_run_command(dag_maker): - dagrun = setup_dagrun(dag_maker) - tis = dagrun.task_instances - print(f"command: {tis[0].command_as_list()}") - dag_id, task_id = BaseExecutor.validate_airflow_tasks_run_command(tis[0].command_as_list()) - print(f"dag_id: {dag_id}, task_id: {task_id}") - assert dag_id == dagrun.dag_id - assert task_id == tis[0].task_id + executor._process_workloads.assert_called_once() + + # Verify it was called with the expected workloads + call_args = executor._process_workloads.call_args[0][0] + assert len(call_args) == 3 @pytest.mark.db_test -@mock.patch( - "airflow.models.taskinstance.TaskInstance.generate_command", - return_value=["airflow", "tasks", "run", "--test_dag", "--test_task"], -) -def test_validate_airflow_tasks_run_command_with_complete_forloop(generate_command_mock, dag_maker): - dagrun = setup_dagrun(dag_maker) - tis = dagrun.task_instances - dag_id, task_id = BaseExecutor.validate_airflow_tasks_run_command(tis[0].command_as_list()) - assert dag_id is None - assert task_id is None +def test_trigger_running_tasks(dag_maker): + """Test that trigger_tasks() works when tasks are re-queued.""" + executor, dagrun = setup_trigger_tasks(dag_maker) + executor.trigger_tasks(open_slots=10) + executor._process_workloads.assert_called_once() -@pytest.mark.db_test -@mock.patch( - "airflow.models.taskinstance.TaskInstance.generate_command", return_value=["airflow", "task", "run"] -) -def test_invalid_airflow_tasks_run_command(generate_command_mock, dag_maker): - dagrun = setup_dagrun(dag_maker) - tis = dagrun.task_instances - with pytest.raises(ValueError): - BaseExecutor.validate_airflow_tasks_run_command(tis[0].command_as_list()) + # Reset mock for second call + executor._process_workloads.reset_mock() + # Re-queue one task (simulates retry scenario) + ti = dagrun.task_instances[0] -@pytest.mark.db_test -@mock.patch( - "airflow.models.taskinstance.TaskInstance.generate_command", return_value=["airflow", "tasks", "run"] -) -def test_empty_airflow_tasks_run_command(generate_command_mock, dag_maker): - dagrun = setup_dagrun(dag_maker) - tis = dagrun.task_instances - dag_id, task_id = BaseExecutor.validate_airflow_tasks_run_command(tis[0].command_as_list()) - assert dag_id is None, task_id is None + workload = workloads.ExecuteTask.make(ti) + executor.queued_tasks[ti.key] = workload + + executor.trigger_tasks(open_slots=10) + + # Verify _process_workloads was called again + executor._process_workloads.assert_called_once() def test_debug_dump(caplog): diff --git 
a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index c1f926cfdbed0..0bb42c18594c2 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -197,12 +197,17 @@ def set_instance_attrs(self) -> Generator: @pytest.fixture def mock_executors(self): + mock_jwt_generator = MagicMock() + mock_jwt_generator.generate.return_value = "mock-token" + default_executor = mock.MagicMock(name="DefaultExecutor", slots_available=8, slots_occupied=0) default_executor.name = ExecutorName(alias="default_exec", module_path="default.exec.module.path") + default_executor.jwt_generator = mock_jwt_generator second_executor = mock.MagicMock(name="SeconadaryExecutor", slots_available=8, slots_occupied=0) second_executor.name = ExecutorName(alias="secondary_exec", module_path="secondary.exec.module.path") + second_executor.jwt_generator = mock_jwt_generator - # TODO: Task-SDK Make it look like a bound method. Needed until we remove the old queue_command + # TODO: Task-SDK Make it look like a bound method. Needed until we remove the old queue_workload # interface from executors default_executor.queue_workload.__func__ = BaseExecutor.queue_workload second_executor.queue_workload.__func__ = BaseExecutor.queue_workload @@ -1544,12 +1549,12 @@ def test_enqueue_task_instances_with_queued_state(self, dag_maker, session): dr1 = dag_maker.create_dagrun() ti1 = dr1.get_task_instance(task1.task_id, session) - with patch.object(BaseExecutor, "queue_command") as mock_queue_command: + with patch.object(BaseExecutor, "queue_workload") as mock_queue_workload: self.job_runner._enqueue_task_instances_with_queued_state( [ti1], executor=scheduler_job.executor, session=session ) - assert mock_queue_command.called + assert mock_queue_workload.called session.rollback() @pytest.mark.parametrize("state", [State.FAILED, State.SUCCESS]) @@ -1570,14 +1575,14 @@ def test_enqueue_task_instances_sets_ti_state_to_None_if_dagrun_in_finish_state( session.merge(ti) session.commit() - with patch.object(BaseExecutor, "queue_command") as mock_queue_command: + with patch.object(BaseExecutor, "queue_workload") as mock_queue_workload: self.job_runner._enqueue_task_instances_with_queued_state( [ti], executor=scheduler_job.executor, session=session ) session.flush() ti.refresh_from_db(session=session) assert ti.state == State.NONE - mock_queue_command.assert_not_called() + mock_queue_workload.assert_not_called() @pytest.mark.parametrize( "task1_exec, task2_exec", diff --git a/devel-common/src/tests_common/test_utils/mock_executor.py b/devel-common/src/tests_common/test_utils/mock_executor.py index 4261ae80cd0ce..4e95ed3a4eea7 100644 --- a/devel-common/src/tests_common/test_utils/mock_executor.py +++ b/devel-common/src/tests_common/test_utils/mock_executor.py @@ -18,14 +18,20 @@ from __future__ import annotations from collections import defaultdict +from collections.abc import Sequence +from typing import TYPE_CHECKING from unittest.mock import MagicMock from airflow.executors.base_executor import BaseExecutor from airflow.executors.executor_utils import ExecutorName +from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancekey import TaskInstanceKey from airflow.utils.session import create_session from airflow.utils.state import State +if TYPE_CHECKING: + from airflow.executors import workloads + class MockExecutor(BaseExecutor): """TestExecutor is used for unit testing purposes.""" @@ -50,11 +56,23 @@ def 
__init__(self, do_update=True, *args, **kwargs): # So we should pass self.success instead of lambda. self.mock_task_results = defaultdict(self.success) + # Mock JWT generator for token generation + mock_jwt_generator = MagicMock() + mock_jwt_generator.generate.return_value = "mock-token" + + self.jwt_generator = mock_jwt_generator + super().__init__(*args, **kwargs) def success(self): return State.SUCCESS + def _process_workloads(self, workload_list: Sequence[workloads.All]) -> None: + """Process the given workloads - mock implementation.""" + # For mock executor, we don't actually process the workloads, + # they get processed in heartbeat() + pass + def heartbeat(self): if not self.do_update: return @@ -65,17 +83,25 @@ def heartbeat(self): # Create a stable/predictable sort order for events in self.history # for tests! def sort_by(item): - key, val = item + key, workload = item (dag_id, task_id, date, try_number, map_index) = key - (_, prio, _, _) = val + # For workloads, use the task instance priority if available + prio = getattr(workload.ti, "priority_weight", 1) if hasattr(workload, "ti") else 1 # Sort by priority (DESC), then date,task, try return -prio, date, dag_id, task_id, map_index, try_number open_slots = self.parallelism - len(self.running) sorted_queue = sorted(self.queued_tasks.items(), key=sort_by) - for key, (_, _, _, ti) in sorted_queue[:open_slots]: + for key, workload in sorted_queue[:open_slots]: self.queued_tasks.pop(key) state = self.mock_task_results[key] + ti = TaskInstance.get_task_instance( + task_id=workload.ti.task_id, + run_id=workload.ti.run_id, + dag_id=workload.ti.dag_id, + map_index=workload.ti.map_index, + lock_for_update=True, + ) ti.set_state(state, session=session) self.change_state(key, state) session.flush() diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py index 3f477095f4474..2e3387893625e 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py @@ -203,7 +203,7 @@ def queue_workload(self, workload: workloads.All, session: Session | None) -> No ti = workload.ti self.queued_tasks[ti.key] = workload - def _process_workloads(self, workloads: list[workloads.All]) -> None: + def _process_workloads(self, workloads: Sequence[workloads.All]) -> None: from airflow.executors.workloads import ExecuteTask for w in workloads: diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py index 8625499ab1163..c28d22ee2b36b 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py @@ -131,7 +131,7 @@ def queue_workload(self, workload: workloads.All, session: Session | None) -> No ti = workload.ti self.queued_tasks[ti.key] = workload - def _process_workloads(self, workloads: list[workloads.All]) -> None: + def _process_workloads(self, workloads: Sequence[workloads.All]) -> None: from airflow.executors.workloads import ExecuteTask # Airflow V3 version diff --git a/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py b/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py index 0561bc45357d7..31c19edd05c02 
100644 --- a/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py +++ b/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py @@ -607,6 +607,7 @@ def _mock_sync( } executor.batch.describe_jobs.return_value = {"jobs": [after_batch_job]} + @pytest.mark.skip(reason="Adopting task instances hasn't been ported over to Airflow 3 yet") def test_try_adopt_task_instances(self, mock_executor): """Test that executor can adopt orphaned task instances from a SchedulerJob shutdown event.""" mock_executor.batch.describe_jobs.return_value = { diff --git a/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py b/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py index 70cedadab341d..a29cb883a7b42 100644 --- a/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py +++ b/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py @@ -1249,6 +1249,7 @@ def test_update_running_tasks_failed(self, mock_executor, caplog): "test failure" in caplog.messages[0] ) + @pytest.mark.skip(reason="Adopting task instances hasn't been ported over to Airflow 3 yet") def test_try_adopt_task_instances(self, mock_executor): """Test that executor can adopt orphaned task instances from a SchedulerJob shutdown event.""" mock_executor.ecs.describe_tasks.return_value = { diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py index cfb743eed47ff..0a5aee9c0308b 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py @@ -66,14 +66,14 @@ if TYPE_CHECKING: import argparse + from collections.abc import Sequence from sqlalchemy.orm import Session from airflow.executors import workloads - from airflow.executors.base_executor import TaskTuple from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancekey import TaskInstanceKey - from airflow.providers.celery.executors.celery_executor_utils import TaskInstanceInCelery + from airflow.providers.celery.executors.celery_executor_utils import TaskInstanceInCelery, TaskTuple # PEP562 @@ -256,7 +256,7 @@ def _num_tasks_per_send_process(self, to_send_count: int) -> int: """ return max(1, math.ceil(to_send_count / self._sync_parallelism)) - def _process_tasks(self, task_tuples: list[TaskTuple]) -> None: + def _process_tasks(self, task_tuples: Sequence[TaskTuple]) -> None: # Airflow V2 version from airflow.providers.celery.executors.celery_executor_utils import execute_command @@ -264,18 +264,18 @@ def _process_tasks(self, task_tuples: list[TaskTuple]) -> None: self._send_tasks(task_tuples_to_send) - def _process_workloads(self, input: list[workloads.All]) -> None: + def _process_workloads(self, workloads: Sequence[workloads.All]) -> None: # Airflow V3 version -- have to delay imports until we know we are on v3 - from airflow.executors import workloads + from airflow.executors.workloads import ExecuteTask from airflow.providers.celery.executors.celery_executor_utils import execute_workload tasks = [ (workload.ti.key, workload, workload.ti.queue, execute_workload) - for workload in input - if isinstance(workload, workloads.ExecuteTask) + for workload in workloads + if isinstance(workload, ExecuteTask) ] - if len(tasks) != len(input): - invalid = list(workload for workload in input if not isinstance(workload, workloads.ExecuteTask)) + if 
len(tasks) != len(workloads): + invalid = list(workload for workload in workloads if not isinstance(workload, ExecuteTask)) raise ValueError(f"{type(self)}._process_workloads cannot handle {invalid}") self._send_tasks(tasks) diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py index 4c16c90a83fdc..6bef6ae9e3e81 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py @@ -30,7 +30,7 @@ import sys import traceback import warnings -from collections.abc import Mapping, MutableMapping +from collections.abc import Mapping, MutableMapping, Sequence from concurrent.futures import ProcessPoolExecutor from typing import TYPE_CHECKING, Any, Optional, Union @@ -63,16 +63,20 @@ from celery.result import AsyncResult from airflow.executors import workloads - from airflow.executors.base_executor import CommandType, EventBufferValueType + from airflow.executors.base_executor import EventBufferValueType from airflow.models.taskinstance import TaskInstanceKey from airflow.typing_compat import TypeAlias # We can't use `if AIRFLOW_V_3_0_PLUS` conditions in type checks, so unfortunately we just have to define # the type as the union of both kinds + CommandType = Sequence[str] + TaskInstanceInCelery: TypeAlias = tuple[ TaskInstanceKey, Union[workloads.All, CommandType], Optional[str], Task ] + TaskTuple = tuple[TaskInstanceKey, CommandType, Optional[str], Optional[Any]] + OPERATION_TIMEOUT = conf.getfloat("celery", "operation_timeout") # Make it constant for unit test. @@ -181,7 +185,7 @@ def execute_workload(input: str) -> None: @app.task def execute_command(command_to_exec: CommandType) -> None: """Execute command.""" - dag_id, task_id = BaseExecutor.validate_airflow_tasks_run_command(command_to_exec) + dag_id, task_id = BaseExecutor.validate_airflow_tasks_run_command(command_to_exec) # type: ignore[attr-defined] celery_task_id = app.current_task.request.id log.info("[%s] Executing command in Celery: %s", celery_task_id, command_to_exec) with _airflow_parsing_context_manager(dag_id=dag_id, task_id=task_id): diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py index 8d63150f5448f..8e895f6ed9044 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py @@ -19,7 +19,7 @@ from collections.abc import Sequence from functools import cached_property -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from deprecated import deprecated @@ -38,14 +38,12 @@ if TYPE_CHECKING: from airflow.callbacks.base_callback_sink import BaseCallbackSink from airflow.callbacks.callback_requests import CallbackRequest - from airflow.executors.base_executor import ( - CommandType, - EventBufferValueType, - QueuedTaskInstanceType, - ) + from airflow.executors.base_executor import EventBufferValueType from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance from airflow.models.taskinstancekey import TaskInstanceKey + CommandType = Sequence[str] + class CeleryKubernetesExecutor(BaseExecutor): """ @@ -93,7 +91,7 @@ def _task_event_logs(self, value): """Not implemented for hybrid executors.""" 
@property - def queued_tasks(self) -> dict[TaskInstanceKey, QueuedTaskInstanceType]: + def queued_tasks(self) -> dict[TaskInstanceKey, Any]: """Return queued tasks from celery and kubernetes executor.""" queued_tasks = self.celery_executor.queued_tasks.copy() queued_tasks.update(self.kubernetes_executor.queued_tasks) # type: ignore[arg-type] @@ -155,7 +153,7 @@ def queue_command( """Queues command via celery or kubernetes executor.""" executor = self._router(task_instance) self.log.debug("Using executor: %s for %s", executor.__class__.__name__, task_instance.key) - executor.queue_command(task_instance, command, priority, queue) + executor.queue_command(task_instance, command, priority, queue) # type: ignore[union-attr] def queue_task_instance( self, @@ -182,7 +180,7 @@ def queue_task_instance( if not hasattr(task_instance, "pickle_id"): del kwargs["pickle_id"] - executor.queue_task_instance( + executor.queue_task_instance( # type: ignore[union-attr] task_instance=task_instance, mark_success=mark_success, ignore_all_deps=ignore_all_deps, diff --git a/providers/celery/tests/unit/celery/executors/test_celery_kubernetes_executor.py b/providers/celery/tests/unit/celery/executors/test_celery_kubernetes_executor.py index 267b0fa78fd87..cadf718ec1a7e 100644 --- a/providers/celery/tests/unit/celery/executors/test_celery_kubernetes_executor.py +++ b/providers/celery/tests/unit/celery/executors/test_celery_kubernetes_executor.py @@ -88,6 +88,7 @@ def test_start(self): celery_executor_mock.start.assert_called() k8s_executor_mock.start.assert_called() + @pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Airflow 3 doesn't have queue_command anymore") @pytest.mark.parametrize("test_queue", ["any-other-queue", KUBERNETES_QUEUE]) @mock.patch.object(CeleryExecutor, "queue_command") @mock.patch.object(KubernetesExecutor, "queue_command") diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py index d475fba6817c4..46e41e348f145 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py @@ -26,16 +26,17 @@ from kubernetes.client.api_client import ApiClient from kubernetes.client.rest import ApiException -from airflow.models import DagRun, TaskInstance +from airflow.models import DagModel, DagRun, TaskInstance from airflow.providers.cncf.kubernetes import pod_generator from airflow.providers.cncf.kubernetes.executors.kubernetes_executor import KubeConfig from airflow.providers.cncf.kubernetes.kube_client import get_kube_client from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import create_unique_id -from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator +from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator, generate_pod_command_args from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_0_PLUS from airflow.utils import cli as cli_utils, yaml from airflow.utils.cli import get_dag from airflow.utils.providers_configuration_loader import providers_configuration_loaded +from airflow.utils.types import DagRunType @cli_utils.action_cli @@ -48,14 +49,28 @@ def generate_pod_yaml(args): else: dag = get_dag(subdir=args.subdir, dag_id=args.dag_id) yaml_output_path = args.output_path + + dm = DagModel(dag_id=dag.dag_id) + if AIRFLOW_V_3_0_PLUS: dr = DagRun(dag.dag_id, 
logical_date=logical_date) + dr.run_id = DagRun.generate_run_id( + run_type=DagRunType.MANUAL, logical_date=logical_date, run_after=logical_date + ) + dm.bundle_name = args.bundle_name if args.bundle_name else "default" + dm.relative_fileloc = dag.relative_fileloc else: dr = DagRun(dag.dag_id, execution_date=logical_date) + dr.run_id = DagRun.generate_run_id(run_type=DagRunType.MANUAL, execution_date=logical_date) + kube_config = KubeConfig() + for task in dag.tasks: - ti = TaskInstance(task, None) + ti = TaskInstance(task, run_id=dr.run_id) ti.dag_run = dr + ti.dag_model = dm + + command_args = generate_pod_command_args(ti) pod = PodGenerator.construct_pod( dag_id=args.dag_id, task_id=ti.task_id, @@ -63,7 +78,7 @@ def generate_pod_yaml(args): try_number=ti.try_number, kube_image=kube_config.kube_image, date=ti.logical_date if AIRFLOW_V_3_0_PLUS else ti.execution_date, - args=ti.command_as_list(), + args=command_args, pod_override_object=PodGenerator.from_obj(ti.executor_config), scheduler_job_id="worker-config", namespace=kube_config.executor_namespace, diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py index d047886061d76..37908c176e616 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py @@ -76,13 +76,13 @@ if TYPE_CHECKING: import argparse + from collections.abc import Sequence from kubernetes import client from kubernetes.client import models as k8s from sqlalchemy.orm import Session from airflow.executors import workloads - from airflow.executors.base_executor import CommandType from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancekey import TaskInstanceKey from airflow.providers.cncf.kubernetes.executors.kubernetes_executor_types import ( @@ -254,7 +254,7 @@ def start(self) -> None: def execute_async( self, key: TaskInstanceKey, - command: CommandType, + command: Any, queue: str | None = None, executor_config: Any | None = None, ) -> None: @@ -292,7 +292,7 @@ def queue_workload(self, workload: workloads.All, session: Session | None) -> No ti = workload.ti self.queued_tasks[ti.key] = workload - def _process_workloads(self, workloads: list[workloads.All]) -> None: + def _process_workloads(self, workloads: Sequence[workloads.All]) -> None: from airflow.executors.workloads import ExecuteTask # Airflow V3 version diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_types.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_types.py index bd4274025a8a9..077ac422d4228 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_types.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_types.py @@ -20,10 +20,14 @@ ADOPTED = "adopted" if TYPE_CHECKING: - from airflow.executors.base_executor import CommandType + from collections.abc import Sequence + from airflow.models.taskinstance import TaskInstanceKey from airflow.utils.state import TaskInstanceState + # TODO: Remove after Airflow 2 support is removed + CommandType = Sequence[str] + # TaskInstance key, command, configuration, pod_template_file KubernetesJobType = tuple[TaskInstanceKey, CommandType, 
Any, Optional[str]] diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py index b6c2b097cdf58..b7e03459d6089 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py @@ -41,7 +41,7 @@ annotations_to_key, create_unique_id, ) -from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator +from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator, workload_to_command_args from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.singleton import Singleton from airflow.utils.state import TaskInstanceState @@ -387,20 +387,12 @@ def run_next(self, next_job: KubernetesJobType) -> None: key, command, kube_executor_config, pod_template_file = next_job dag_id, task_id, run_id, try_number, map_index = key - ser_input = "" if len(command) == 1: from airflow.executors.workloads import ExecuteTask if isinstance(command[0], ExecuteTask): workload = command[0] - ser_input = workload.model_dump_json() - command = [ - "python", - "-m", - "airflow.sdk.execution_time.execute_workload", - "--json-string", - ser_input, - ] + command = workload_to_command_args(workload) else: raise ValueError( f"KubernetesExecutor doesn't know how to handle workload of type: {type(command[0])}" diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py index 5be33483b31c3..0eb6e5a1f2157 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py @@ -18,7 +18,7 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from deprecated import deprecated @@ -30,14 +30,12 @@ if TYPE_CHECKING: from airflow.callbacks.base_callback_sink import BaseCallbackSink from airflow.callbacks.callback_requests import CallbackRequest - from airflow.executors.base_executor import ( - CommandType, - EventBufferValueType, - QueuedTaskInstanceType, - ) + from airflow.executors.base_executor import EventBufferValueType from airflow.executors.local_executor import LocalExecutor from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance, TaskInstanceKey + CommandType = Sequence[str] + class LocalKubernetesExecutor(BaseExecutor): """ @@ -81,7 +79,7 @@ def _task_event_logs(self, value): """Not implemented for hybrid executors.""" @property - def queued_tasks(self) -> dict[TaskInstanceKey, QueuedTaskInstanceType]: + def queued_tasks(self) -> dict[TaskInstanceKey, Any]: """Return queued tasks from local and kubernetes executor.""" queued_tasks = self.local_executor.queued_tasks.copy() # TODO: fix this, there is misalignment between the types of queued_tasks so it is likely wrong @@ -145,7 +143,7 @@ def queue_command( """Queues command via local or kubernetes executor.""" executor = self._router(task_instance) self.log.debug("Using executor: %s for %s", executor.__class__.__name__, task_instance.key) - executor.queue_command(task_instance, command, priority, queue) 
+ executor.queue_command(task_instance, command, priority, queue) # type: ignore[union-attr] def queue_task_instance( self, @@ -171,7 +169,7 @@ def queue_task_instance( if not hasattr(task_instance, "pickle_id"): del kwargs["pickle_id"] - executor.queue_task_instance( + executor.queue_task_instance( # type: ignore[union-attr] task_instance=task_instance, mark_success=mark_success, ignore_all_deps=ignore_all_deps, diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_generator.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_generator.py index a0cc2b8cf8b5e..478a35045cc74 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_generator.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_generator.py @@ -46,6 +46,7 @@ POD_NAME_MAX_LENGTH, add_unique_suffix, ) +from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_0_PLUS from airflow.utils import yaml from airflow.utils.hashlib_wrapper import md5 from airflow.version import version as airflow_version @@ -53,11 +54,49 @@ if TYPE_CHECKING: import datetime + from airflow.executors import workloads + from airflow.models.taskinstance import TaskInstance + log = logging.getLogger(__name__) MAX_LABEL_LEN = 63 +def workload_to_command_args(workload: workloads.ExecuteTask) -> list[str]: + """ + Convert a workload object to Task SDK command arguments. + + :param workload: The ExecuteTask workload to convert + :return: List of command arguments for the Task SDK + """ + ser_input = workload.model_dump_json() + return [ + "python", + "-m", + "airflow.sdk.execution_time.execute_workload", + "--json-string", + ser_input, + ] + + +def generate_pod_command_args(task_instance: TaskInstance) -> list[str]: + """ + Generate command arguments for a ``TaskInstance`` to be used in a Kubernetes pod. + + This function handles backwards compatibility between Airflow 2.x and 3.x: + - In Airflow 2.x: Uses the existing ``command_as_list()`` method + - In Airflow 3.x: Uses the Task SDK workload approach with serialized workload + """ + if AIRFLOW_V_3_0_PLUS: + # In Airflow 3+, use the Task SDK workload approach + from airflow.executors import workloads + + workload = workloads.ExecuteTask.make(task_instance) + return workload_to_command_args(workload) + # In Airflow 2.x, use the existing method + return task_instance.command_as_list() + + def make_safe_label_value(string: str) -> str: """ Normalize a provided label to be of valid length and characters. 
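Below is a minimal sketch of an executor written against the workloads-only BaseExecutor contract that remains after this change. The class name, the subprocess launch, and the in-method bookkeeping are assumptions for illustration; the inherited queue_workload()/_process_workloads() hand-off, the ExecuteTask workload shape, and the Task SDK invocation mirror code that appears elsewhere in this diff.

from __future__ import annotations

import subprocess
from collections.abc import Sequence

from airflow.executors import workloads
from airflow.executors.base_executor import BaseExecutor


class SubprocessWorkloadExecutor(BaseExecutor):
    """Hypothetical executor that runs each ExecuteTask workload via the Task SDK entrypoint."""

    def _process_workloads(self, workload_list: Sequence[workloads.All]) -> None:
        for workload in workload_list:
            if not isinstance(workload, workloads.ExecuteTask):
                raise ValueError(f"Unhandled workload kind: {type(workload).__name__}")
            key = workload.ti.key
            # Bookkeeping assumed here: concrete executors move the key out of
            # queued_tasks and into running once they accept the workload.
            self.queued_tasks.pop(key, None)
            self.running.add(key)
            # Same invocation shape as workload_to_command_args() in this patch:
            # serialize the workload and hand it to the Task SDK runner.
            subprocess.Popen(
                [
                    "python",
                    "-m",
                    "airflow.sdk.execution_time.execute_workload",
                    "--json-string",
                    workload.model_dump_json(),
                ]
            )

Scheduling itself is unchanged by such a subclass: the scheduler builds workloads.ExecuteTask.make(ti, generator=executor.jwt_generator), calls executor.queue_workload(), and BaseExecutor.trigger_tasks() batches the queued workloads into _process_workloads().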
diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/template_rendering.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/template_rendering.py index 7f2cd83f33f4c..c16ab6b0cc345 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/template_rendering.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/template_rendering.py @@ -25,7 +25,7 @@ from airflow.exceptions import AirflowException from airflow.providers.cncf.kubernetes.kube_config import KubeConfig from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import create_unique_id -from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator +from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator, generate_pod_command_args from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: @@ -43,6 +43,10 @@ def render_k8s_pod_yaml(task_instance: TaskInstance) -> dict | None: # If no such pod_template_file override was passed, we can simply render # The pod spec using the default template. pod_template_file = kube_config.pod_template_file + + # Generate command args using shared utility function + command_args = generate_pod_command_args(task_instance) + pod = PodGenerator.construct_pod( dag_id=task_instance.dag_id, run_id=task_instance.run_id, @@ -52,7 +56,7 @@ def render_k8s_pod_yaml(task_instance: TaskInstance) -> dict | None: pod_id=create_unique_id(task_instance.dag_id, task_instance.task_id), try_number=task_instance.try_number, kube_image=kube_config.kube_image, - args=task_instance.command_as_list(), + args=command_args, pod_override_object=PodGenerator.from_obj(task_instance.executor_config), scheduler_job_id="0", namespace=kube_config.executor_namespace, diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/test_template_rendering.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/test_template_rendering.py index d7e34d4125910..1209058f10fd3 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/test_template_rendering.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/test_template_rendering.py @@ -48,6 +48,29 @@ def test_render_k8s_pod_yaml(pod_mutation_hook, create_task_instance): logical_date=DEFAULT_DATE, ) + if AIRFLOW_V_3_0_PLUS: + from airflow.executors import workloads + + workload = workloads.ExecuteTask.make(ti) + rendered_args = [ + "python", + "-m", + "airflow.sdk.execution_time.execute_workload", + "--json-string", + workload.model_dump_json(), + ] + else: + rendered_args = [ + "airflow", + "tasks", + "run", + "test_render_k8s_pod_yaml", + "op1", + "test_run_id", + "--subdir", + mock.ANY, + ] + expected_pod_spec = { "metadata": { "annotations": { @@ -71,16 +94,7 @@ def test_render_k8s_pod_yaml(pod_mutation_hook, create_task_instance): "spec": { "containers": [ { - "args": [ - "airflow", - "tasks", - "run", - "test_render_k8s_pod_yaml", - "op1", - "test_run_id", - "--subdir", - mock.ANY, - ], + "args": rendered_args, "name": "base", "env": [{"name": "AIRFLOW_IS_K8S_EXECUTOR_POD", "value": "True"}], } diff --git a/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py b/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py index a2ebe43565299..b203f12dcc5a7 100644 --- a/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py +++ b/providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py @@ -47,9 +47,10 @@ from sqlalchemy.engine.base import Engine - from 
airflow.executors.base_executor import CommandType from airflow.models.taskinstancekey import TaskInstanceKey + # TODO: Airflow 2 type hints; remove when Airflow 2 support is removed + CommandType = Sequence[str] # Task tuple to send to be executed TaskTuple = tuple[TaskInstanceKey, CommandType, Optional[str], Optional[Any]] @@ -108,7 +109,7 @@ def _process_tasks(self, task_tuples: list[TaskTuple]) -> None: Store queued_tasks in own var to be able to access this in execute_async function. """ self.edge_queued_tasks = deepcopy(self.queued_tasks) - super()._process_tasks(task_tuples) + super()._process_tasks(task_tuples) # type: ignore[misc] @provide_session def execute_async( @@ -122,10 +123,11 @@ def execute_async( """Execute asynchronously. Airflow 2.10 entry point to execute a task.""" # Use of a temponary trick to get task instance, will be changed with Airflow 3.0.0 # code works together with _process_tasks overwrite to get task instance. - task_instance = self.edge_queued_tasks[key][3] # TaskInstance in fourth element + # TaskInstance in fourth element + task_instance = self.edge_queued_tasks[key][3] # type: ignore[index] del self.edge_queued_tasks[key] - self.validate_airflow_tasks_run_command(command) + self.validate_airflow_tasks_run_command(command) # type: ignore[attr-defined] session.add( EdgeJobModel( dag_id=key.dag_id, diff --git a/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py b/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py index 1f13e133f5571..02ab7582a54cd 100644 --- a/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py +++ b/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py @@ -57,12 +57,14 @@ def get_test_executor(self, pool_slots=1): return (executor, key) + @pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="_process_tasks is not used in Airflow 3.0+") def test__process_tasks_bad_command(self): executor, key = self.get_test_executor() task_tuple = (key, ["hello", "world"], None, None) with pytest.raises(ValueError): executor._process_tasks([task_tuple]) + @pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="_process_tasks is not used in Airflow 3.0+") @pytest.mark.parametrize( "pool_slots, expected_concurrency", [ @@ -291,6 +293,7 @@ def test_sync_active_worker(self): else: assert worker.state == EdgeWorkerState.IDLE + @pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="API only available in Airflow <3.0") def test_execute_async(self): executor, key = self.get_test_executor()