diff --git a/airflow-core/docs/migrations-ref.rst b/airflow-core/docs/migrations-ref.rst index 3b2080410e228..54a0af0a8d9e7 100644 --- a/airflow-core/docs/migrations-ref.rst +++ b/airflow-core/docs/migrations-ref.rst @@ -39,7 +39,9 @@ Here's the list of all the Database Migrations that are executed via when you ru +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | Revision ID | Revises ID | Airflow Version | Description | +=========================+==================+===================+==============================================================+ -| ``e79fc784f145`` (head) | ``0b112f49112d`` | ``3.2.0`` | add timetable_type to dag table for filtering. | +| ``658517c60c7f`` (head) | ``e79fc784f145`` | ``3.2.0`` | Add ``next_trigger_id`` column to ``task_instance`` table. | ++-------------------------+------------------+-------------------+--------------------------------------------------------------+ +| ``e79fc784f145`` | ``0b112f49112d`` | ``3.2.0`` | add timetable_type to dag table for filtering. | +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | ``0b112f49112d`` | ``c47f2e1ab9d4`` | ``3.2.0`` | Add exceeds max runs flag to dag model. | +-------------------------+------------------+-------------------+--------------------------------------------------------------+ diff --git a/airflow-core/src/airflow/jobs/queues.py b/airflow-core/src/airflow/jobs/queues.py new file mode 100644 index 0000000000000..54f2a12e0d303 --- /dev/null +++ b/airflow-core/src/airflow/jobs/queues.py @@ -0,0 +1,182 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from asyncio import Lock as AsyncLock, Queue +from collections import OrderedDict, defaultdict, deque +from collections.abc import Iterable, Iterator +from threading import Lock +from typing import Generic, TypeVar + +K = TypeVar("K") +V = TypeVar("V") +KV = TypeVar("KV", bound=tuple) + + +class KeyedHeadQueue(Generic[K, KV]): + """ + A keyed queue that manages values per key in insertion order. + + Features: + - `popleft()` returns only the *first value* per key (in insertion order of keys). + - Once a key's first value is popped, that key will never yield in `popleft()` again. + - Remaining values for consumed keys are preserved. + - Iteration yields those leftover (key, value) pairs. 
+
+    Example:
+        q = KeyedHeadQueue()
+        q.append(("task1", "event1"))
+        q.append(("task1", "event2"))
+        q.append(("task2", "eventA"))
+
+        q.popleft()  # ('task1', 'event1')
+        q.popleft()  # ('task2', 'eventA')
+
+        list(q)  # [('task1', ('task1', 'event2'))]
+    """
+
+    def __init__(self) -> None:
+        self.__map: OrderedDict[K, deque[KV]] = OrderedDict()  # key -> deque of values
+        self.__popped_keys: set[K] = set()  # keys whose first value has been consumed
+        self._lock = Lock()
+
+    @property
+    def _map(self) -> OrderedDict[K, list[KV]]:
+        with self._lock:
+            return OrderedDict((key, list(value)) for key, value in self.__map.items())
+
+    @property
+    def _popped_keys(self) -> set[K]:
+        with self._lock:
+            return set(self.__popped_keys)
+
+    def get(self, key: K, default_value: list[KV] | None = None) -> list[KV] | None:
+        return list(self._map.get(key, default_value or []))
+
+    def extend(self, elements: Iterable[KV]) -> None:
+        for element in elements:
+            self.append(element)
+
+    def append(self, element: KV) -> None:
+        """Append a (key, value) element under its key; values for already-consumed keys are kept as leftovers."""
+        key = element[0]
+        with self._lock:
+            if key not in self.__map:
+                self.__map[key] = deque()
+            self.__map[key].append(element)
+
+    def popleft(self) -> KV:
+        """
+        Pop the *first inserted value* for the next key in order.
+
+        Raises IndexError if all first values have been popped.
+        """
+        with self._lock:
+            for key, values in self.__map.items():
+                if key not in self.__popped_keys:
+                    value = values.popleft()
+                    self.__popped_keys.add(key)
+                    if not values:
+                        del self.__map[key]
+                    return value
+            raise IndexError("pop from empty KeyedHeadQueue")
+
+    def popall(self) -> tuple[K, list[KV]]:
+        """
+        Pop all values for the first unconsumed key (in insertion order).
+
+        Marks the key as consumed.
+        Raises IndexError if no keys remain.
+        """
+        with self._lock:
+            for key in self.__map.keys():
+                if key not in self.__popped_keys:
+                    values = list(self.__map.pop(key, []))
+                    self.__popped_keys.add(key)
+                    return key, values
+
+            raise IndexError("pop from empty KeyedHeadQueue")
+
+    def __contains__(self, key: K) -> bool:
+        return key in self._map
+
+    def __iter__(self) -> Iterator[tuple[K, KV]]:
+        """Iterate over leftover (key, value) pairs in a snapshot, so concurrent appends during iteration are not visible."""
+        for key, values in self._map.items():
+            for value in values:
+                yield key, value
+
+    def __len__(self) -> int:
+        """Count remaining values available."""
+        with self._lock:
+            return sum(len(value) for value in self.__map.values())
+
+    def __bool__(self) -> bool:
+        """Return True while at least one key still has its first value available; reset tracking once exhausted."""
+        with self._lock:
+            if not sum(1 for key in self.__map if key not in self.__popped_keys) > 0:
+                self.__popped_keys.clear()
+                return False
+            return True
+
+    def keys(self) -> list[K]:
+        """Keys still waiting for their first value to be popped."""
+        with self._lock:
+            return [key for key in self.__map.keys() if key not in self.__popped_keys]
+
+
+class PartitionedQueue(Generic[K, V], defaultdict[K, Queue[tuple[K, V]]]):
+    """
+    Dict-like container where each key maps to an asyncio.Queue.
+
+    Tracks sizes safely for concurrent access.
+    Provides put(item) and popleft().
+    Uses a total counter to make __bool__ O(1).
+    Supports both async and threading locks.
+ """ + + def __init__(self, maxsize: int = 0) -> None: + super().__init__(lambda: Queue(maxsize=maxsize)) + self.maxsize = maxsize + self._async_locks: dict[K, AsyncLock] = defaultdict(AsyncLock) + self._locks: dict[K, Lock] = defaultdict(Lock) + self._sizes: dict[K, int] = defaultdict(int) # track sizes per key + self._total_size: int = 0 # total items across all queues + + def __bool__(self) -> bool: + return self._total_size > 0 + + async def put(self, item: tuple[K, V]) -> None: + key = item[0] + queue = self[key] + async with self._async_locks[key]: + await queue.put(item) + with self._locks[key]: + self._sizes[key] += 1 + self._total_size += 1 + + def popleft(self) -> tuple[K, V]: + """Pop an item from the first non-empty queue synchronously (non-blocking) using thread lock.""" + for key, queue in list(self.items()): + with self._locks[key]: + if self._sizes[key] > 0: + item = queue.get_nowait() # won't raise if size > 0 + self._sizes[key] -= 1 + self._total_size -= 1 + return item + raise StopIteration diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 532b0faf1de88..a348596a658d8 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -46,6 +46,7 @@ from airflow.executors import workloads from airflow.jobs.base_job_runner import BaseJobRunner from airflow.jobs.job import perform_heartbeat +from airflow.jobs.queues import KeyedHeadQueue, PartitionedQueue from airflow.models.trigger import Trigger from airflow.observability.trace import DebugTrace, Trace, add_debug_span from airflow.sdk.api.datamodels._generated import HITLDetailResponse @@ -81,7 +82,7 @@ from airflow.triggers import base as events from airflow.utils.helpers import log_filename_template_renderer from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.session import provide_session +from airflow.utils.session import create_session, provide_session if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -94,6 +95,7 @@ from airflow.triggers.base import BaseTrigger logger = logging.getLogger(__name__) +maxsize = conf.getint("triggerer", "max_number_of_events_per_trigger", fallback=1) __all__ = [ "TriggerRunner", @@ -216,7 +218,7 @@ class TriggerStateChanges(BaseModel): Field(default=None), ] # Format of list[str] is the exc traceback format - failures: list[tuple[int, list[str] | None]] | None = None + failures: list[tuple[int, tuple[str, dict[str, Any]] | None, list[str] | None]] | None = None finished: list[int] | None = None class TriggerStateSync(BaseModel): @@ -364,10 +366,17 @@ class TriggerRunnerSupervisor(WatchedSubprocess): creating_triggers: deque[workloads.RunTrigger] = attrs.field(factory=deque, init=False) # Outbound queue of events - events: deque[tuple[int, events.TriggerEvent]] = attrs.field(factory=deque, init=False) + events: KeyedHeadQueue[int, tuple[int, events.TriggerEvent]] = attrs.field( + factory=KeyedHeadQueue, init=False + ) # Outbound queue of failed triggers - failed_triggers: deque[tuple[int, list[str] | None]] = attrs.field(factory=deque, init=False) + failed_triggers: KeyedHeadQueue[int, tuple[int, tuple[str, dict[str, Any]] | None, list[str] | None]] = ( + attrs.field(factory=KeyedHeadQueue, init=False) + ) + + # Outbound queue of finished triggers + finished_triggers: set = attrs.field(factory=set, init=False) def is_alive(self) -> bool: # Set by `_service_subprocess` in the loop @@ -415,6 +424,7 @@ def 
_handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r for id in msg.finished or (): self.running_triggers.discard(id) self.cancelling_triggers.discard(id) + self.finished_triggers.add(id) # Remove logger from the cache, and since structlog doesn't have an explicit close method, we # only need to remove the last reference to it to close the open FH if factory := self.logger_cache.pop(id, None): @@ -575,18 +585,36 @@ def load_triggers(self): @add_debug_span def handle_events(self): """Dispatch outbound events to the Trigger model which pushes them to the relevant task instances.""" - while self.events: - # Get the event and its trigger ID - trigger_id, event = self.events.popleft() - # Tell the model to wake up its tasks - Trigger.submit_event(trigger_id=trigger_id, event=event) - # Emit stat event - Stats.incr("triggers.succeeded") + if self.events: + with create_session() as session: + while self.events: + trigger_id, event = self.events.popleft() + is_last_event = trigger_id not in self.events + remaining_events = len(self.events.get(trigger_id, [])) + log.info( + "Trigger %s has %s remaining events and %s running triggers: %s", + trigger_id, + remaining_events, + len(self.running_triggers), + len(self.running_triggers), + ) + + # Tell the model to wake up its tasks + if Trigger.submit_event( + trigger_id=trigger_id, event=event, is_last_event=is_last_event, session=session + ): + # This is temporary logging to ease debugging, will be omitted in Airflow code base + log.info("Event %s handled for trigger %s", event.payload, trigger_id) + # Emit stat event + Stats.incr("triggers.succeeded") + else: + self.events.append((trigger_id, event)) @add_debug_span def clean_unused(self): """Clean out unused or finished triggers.""" - Trigger.clean_unused() + Trigger.clean_unused(self.finished_triggers.copy()) + self.finished_triggers.clear() @add_debug_span def handle_failed_triggers(self): @@ -595,12 +623,27 @@ def handle_failed_triggers(self): Task Instances that depend on them need failing. 
""" - while self.failed_triggers: - # Tell the model to fail this trigger's deps - trigger_id, saved_exc = self.failed_triggers.popleft() - Trigger.submit_failure(trigger_id=trigger_id, exc=saved_exc) - # Emit stat event - Stats.incr("triggers.failed") + if self.failed_triggers: + log.info("handle_failed_triggers: %d", len(self.failed_triggers)) + with create_session() as session: + while self.failed_triggers: + trigger_id, trigger, saved_exc = self.failed_triggers.popleft() + + # Tell the model to fail this trigger's deps + if trigger_id not in self.events and Trigger.submit_failure( + trigger_id=trigger_id, trigger=trigger, exc=saved_exc, session=session + ): + log.warning("Trigger %s has failed: %s", trigger_id, saved_exc) + # Emit stat event + Stats.incr("triggers.failed") + else: + log.warning( + "Trigger %s has failed but is still processing %d remaining events, so we waiting a bit...", + trigger_id, + len(self.events.get(trigger_id)), + ) + self.failed_triggers.append((trigger_id, trigger, saved_exc)) + session.flush() def emit_metrics(self): DualStatsManager.gauge( @@ -778,6 +821,7 @@ class TriggerDetails(TypedDict): is_watcher: bool name: str events: int + trigger: tuple[str, dict[str, Any]] | None @attrs.define(kw_only=True) @@ -852,10 +896,10 @@ class TriggerRunner: to_cancel: deque[int] # Outbound queue of events - events: deque[tuple[int, events.TriggerEvent]] + events: PartitionedQueue[int, events.DiscrimatedTriggerEvent] # Outbound queue of failed triggers - failed_triggers: deque[tuple[int, BaseException | None]] + failed_triggers: KeyedHeadQueue[int, tuple[int, tuple[str, dict[str, Any]] | None, BaseException | None]] # Should-we-stop flag stop: bool = False @@ -871,8 +915,8 @@ def __init__(self): self.trigger_cache = {} self.to_create = deque() self.to_cancel = deque() - self.events = deque() - self.failed_triggers = deque() + self.events = PartitionedQueue(maxsize=maxsize) + self.failed_triggers = KeyedHeadQueue() self.job_id = None def _handle_signal(self, signum, frame) -> None: @@ -967,7 +1011,7 @@ async def create_triggers(self): except BaseException as e: # Either the trigger code or the path to it is bad. Fail the trigger. self.log.error("Trigger failed to load code", error=e, classpath=workload.classpath) - self.failed_triggers.append((trigger_id, e)) + self.failed_triggers.append((trigger_id, None, e)) continue # Loading the trigger class could have been expensive. Lets give other things a chance to run! 
@@ -986,7 +1030,7 @@ async def create_triggers(self): trigger_instance = trigger_class(**deserialised_kwargs) except TypeError as err: self.log.error("Trigger failed to inflate", error=err) - self.failed_triggers.append((trigger_id, err)) + self.failed_triggers.append((trigger_id, None, err)) continue trigger_instance.trigger_id = trigger_id trigger_instance.triggerer_job_id = self.job_id @@ -1016,8 +1060,13 @@ async def cancel_triggers(self): while self.to_cancel: trigger_id = self.to_cancel.popleft() if trigger_id in self.triggers: - # We only delete if it did not exit already - self.triggers[trigger_id]["task"].cancel() + # We only cancel if it did not exit already + if trigger_id not in self.failed_triggers: + await self.log.ainfo("No need to cancel trigger %s yet...", trigger_id) + elif not self.triggers[trigger_id]["task"].done(): + await self.log.ainfo("Cancelling trigger %s", trigger_id) + self.triggers[trigger_id]["task"].cancel() + pass await asyncio.sleep(0) async def cleanup_finished_triggers(self) -> list[int]: @@ -1028,13 +1077,19 @@ async def cleanup_finished_triggers(self) -> list[int]: """ finished_ids: list[int] = [] for trigger_id, details in list(self.triggers.items()): - if details["task"].done(): + await self.log.ainfo( + "trigger_id %s is %s.", trigger_id, "done" if details["task"].done() else "not done" + ) + if details["task"].done() and trigger_id not in self.events: finished_ids.append(trigger_id) # Check to see if it exited for good reasons saved_exc = None try: result = details["task"].result() - except (asyncio.CancelledError, SystemExit, KeyboardInterrupt): + except (asyncio.CancelledError, SystemExit, KeyboardInterrupt) as e: + await self.log.aexception( + "Trigger %s exited with cancelled error %s", details["name"], e, trigger_id=trigger_id + ) # These are "expected" exceptions and we stop processing here # If we don't, then the system requesting a trigger be removed - # which turns into CancelledError - results in a failure. @@ -1042,14 +1097,15 @@ async def cleanup_finished_triggers(self) -> list[int]: continue except BaseException as e: # This is potentially bad, so log it. - self.log.exception( + await self.log.aexception( "Trigger %s exited with error %s", details["name"], e, trigger_id=trigger_id ) saved_exc = e + self.failed_triggers.append((trigger_id, details.get("trigger"), saved_exc)) else: # See if they foolishly returned a TriggerEvent if isinstance(result, events.TriggerEvent): - self.log.error( + await self.log.aerror( "Trigger returned a TriggerEvent rather than yielding it", trigger=details["name"], trigger_id=trigger_id, @@ -1057,13 +1113,13 @@ async def cleanup_finished_triggers(self) -> list[int]: # See if this exited without sending an event, in which case # any task instances depending on it need to be failed if details["events"] == 0: - self.log.error( + await self.log.aerror( "Trigger exited without sending an event. Dependent tasks will be failed.", name=details["name"], trigger_id=trigger_id, ) # TODO: better formatting of the exception? 
- self.failed_triggers.append((trigger_id, saved_exc)) + self.failed_triggers.append((trigger_id, details.get("trigger"), saved_exc)) del self.triggers[trigger_id] await asyncio.sleep(0) return finished_ids @@ -1077,9 +1133,9 @@ async def sync_state_to_supervisor(self, finished_ids: list[int]): failures_to_send = [] while self.failed_triggers: - id, exc = self.failed_triggers.popleft() + id, trigger, exc = self.failed_triggers.popleft() tb = format_exception(type(exc), exc, exc.__traceback__) if exc else None - failures_to_send.append((id, tb)) + failures_to_send.append((id, trigger, tb)) msg = messages.TriggerStateChanges( events=events_to_send, finished=finished_ids, failures=failures_to_send @@ -1145,15 +1201,22 @@ async def run_trigger(self, trigger_id: int, trigger: BaseTrigger, timeout_after bind_log_contextvars(trigger_id=trigger_id) name = self.triggers[trigger_id]["name"] - self.log.info("trigger %s starting", name) + await self.log.ainfo("trigger %s starting", name) try: async for event in trigger.run(): await self.log.ainfo( "Trigger fired event", name=self.triggers[trigger_id]["name"], result=event ) + await self.log.ainfo( + "%s size: %d / %d", + trigger_id, + self.events[trigger_id].qsize(), + self.events[trigger_id].maxsize, + ) self.triggers[trigger_id]["events"] += 1 - self.events.append((trigger_id, event)) + await self.events.put((trigger_id, event)) except asyncio.CancelledError: + await self.log.aexception("trigger %s failed due to cancelled error", trigger_id) # We get cancelled by the scheduler changing the task state. But if we do lets give a nice error # message about it if timeout := timeout_after: @@ -1161,6 +1224,12 @@ async def run_trigger(self, trigger_id: int, trigger: BaseTrigger, timeout_after if timeout < timezone.utcnow(): await self.log.aerror("Trigger cancelled due to timeout") raise + except Exception: + await self.log.aexception("trigger %s failed", trigger_id) + # We serialize the trigger first before raising the exception, so that when the trigger is retryable, + # we can resume from the point where it failed when the scheduler recreates the trigger. + self.triggers[trigger_id]["trigger"] = trigger.serialize() + raise finally: # CancelledError will get injected when we're stopped - which is # fine, the cleanup process will understand that, but we want to diff --git a/airflow-core/src/airflow/migrations/versions/0098_3_2_0_add_next_trigger_id_to_task_instance_table.py b/airflow-core/src/airflow/migrations/versions/0098_3_2_0_add_next_trigger_id_to_task_instance_table.py new file mode 100644 index 0000000000000..431e2051f8c08 --- /dev/null +++ b/airflow-core/src/airflow/migrations/versions/0098_3_2_0_add_next_trigger_id_to_task_instance_table.py @@ -0,0 +1,47 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +""" +Add ``next_trigger_id`` column to ``task_instance`` table. + +Revision ID: 658517c60c7f +Revises: 0b112f49112d +Create Date: 2025-12-26 12:07:05.849152 + +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +revision = "658517c60c7f" +down_revision = "0b112f49112d" +branch_labels = None +depends_on = None +airflow_version = "3.2.0" + + +def upgrade(): + """Add ``next_trigger_id`` column to ``task_instance`` table.""" + op.add_column("task_instance", sa.Column("next_trigger_id", sa.Integer(), nullable=True)) + + +def downgrade(): + """Remove ``next_trigger_id`` column from ``task_instance`` table.""" + op.drop_column("task_instance", "next_trigger_id") diff --git a/airflow-core/src/airflow/models/dagbag.py b/airflow-core/src/airflow/models/dagbag.py index abf8c41760372..ed8bf1588113a 100644 --- a/airflow-core/src/airflow/models/dagbag.py +++ b/airflow-core/src/airflow/models/dagbag.py @@ -21,9 +21,10 @@ from typing import TYPE_CHECKING, Any from sqlalchemy import String, inspect, select -from sqlalchemy.orm import Mapped, joinedload +from sqlalchemy.orm import Mapped, Session, joinedload from sqlalchemy.orm.attributes import NO_VALUE +from airflow.models import DagRun from airflow.models.base import Base, StringID from airflow.models.dag_version import DagVersion from airflow.utils.sqlalchemy import mapped_column @@ -31,9 +32,6 @@ if TYPE_CHECKING: from collections.abc import Generator - from sqlalchemy.orm import Session - - from airflow.models import DagRun from airflow.models.serialized_dag import SerializedDagModel from airflow.serialization.definitions.dag import SerializedDAG @@ -101,6 +99,12 @@ def get_latest_version_of_dag(self, dag_id: str, *, session: Session) -> Seriali return None return self._read_dag(serdag) + def get_dag(self, dag_id: str, run_id: str, *, session: Session) -> SerializedDAG | None: + dag_run = DagRun.find_duplicate(dag_id=dag_id, run_id=run_id, session=session) + if dag_run: + return self.get_dag_for_run(dag_run=dag_run, session=session) + return None + def generate_md5_hash(context): bundle_name = context.get_current_parameters()["bundle_name"] diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index 3f3bae8b66b9a..d8afed5516bb6 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -1342,6 +1342,15 @@ def recalculate(self) -> _UnfinishedStates: return schedulable_tis, callback + @classmethod + def get_dag_run(cls, dag_id: str, run_id: str, session: Session) -> DagRun | None: + return session.scalars( + select(DagRun).where( + DagRun.dag_id == dag_id, + DagRun.run_id == run_id, + ) + ).one_or_none() + @provide_session def task_instance_scheduling_decisions(self, session: Session = NEW_SESSION) -> TISchedulingDecision: tis = self.get_task_instances(session=session, state=State.task_states) @@ -1362,8 +1371,13 @@ def _filter_tis_and_exclude_removed(dag: SerializedDAG, tis: list[TI]) -> Iterab tis = list(_filter_tis_and_exclude_removed(self.get_dag(), tis)) - unfinished_tis = [t for t in tis if t.state in State.unfinished] finished_tis = [t for t in tis if t.state in State.finished] + uncompleted_tis = [ + t for t in finished_tis if t.next_trigger_id + ] # TODO: this was added to make AIP-88 work + unfinished_tis = [t for t in tis if t.state in State.unfinished] + unfinished_tis.extend(uncompleted_tis) + if unfinished_tis: schedulable_tis = [ut for ut in unfinished_tis if ut.state in SCHEDULEABLE_STATES] 
self.log.debug("number of scheduleable tasks for %s: %s task(s)", self, len(schedulable_tis)) @@ -1377,7 +1391,9 @@ def _filter_tis_and_exclude_removed(dag: SerializedDAG, tis: list[TI]) -> Iterab # states, so we need to re-compute. if expansion_happened: changed_tis = True - new_unfinished_tis = [t for t in unfinished_tis if t.state in State.unfinished] + new_unfinished_tis = [ + t for t in unfinished_tis if t.state in State.unfinished and not t.next_trigger_id + ] finished_tis.extend(t for t in unfinished_tis if t.state in State.finished) unfinished_tis = new_unfinished_tis else: @@ -1550,6 +1566,12 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None: return expanded_tis return () + def is_unmapped_task(ti: TI) -> bool: + from airflow.sdk.definitions.mappedoperator import MappedOperator + + # TODO: check why task is still MappedOperator even when not an unmapped task anymore + return isinstance(ti.task, MappedOperator) and ti.map_index == -1 + # Check dependencies. expansion_happened = False # Set of task ids for which was already done _revise_map_indexes_if_mapped @@ -1573,7 +1595,7 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None: if new_tis is not None: additional_tis.extend(new_tis) expansion_happened = True - if new_tis is None and schedulable.state in SCHEDULEABLE_STATES: + if not new_tis and schedulable.state in SCHEDULEABLE_STATES: # It's enough to revise map index once per task id, # checking the map index for each mapped task significantly slows down scheduling if schedulable.task.task_id not in revised_map_index_task_ids: @@ -1587,7 +1609,7 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None: # _revise_map_indexes_if_mapped might mark the current task as REMOVED # after calculating mapped task length, so we need to re-check # the task state to ensure it's still schedulable - if schedulable.state in SCHEDULEABLE_STATES: + if not is_unmapped_task(schedulable): ready_tis.append(schedulable) # Check if any ti changed state @@ -2052,27 +2074,15 @@ def schedule_tis( empty_ti_ids: list[str] = [] schedulable_ti_ids: list[str] = [] for ti in schedulable_tis: - if ti.is_schedulable: - schedulable_ti_ids.append(ti.id) - # Check "start_trigger_args" to see whether the operator supports - # start execution from triggerer. If so, we'll check "start_from_trigger" + if not ti.is_schedulable: + empty_ti_ids.append(ti.id) + # The defer_task method will check "start_trigger_args" to see whether the operator + # start execution from triggerer. If so, we'll also check "start_from_trigger" # to see whether this feature is turned on and defer this task. # If not, we'll add this "ti" into "schedulable_ti_ids" and later # execute it to run in the worker. - # TODO TaskSDK: This is disabled since we haven't figured out how - # to render start_from_trigger in the scheduler. If we need to - # render the value in a worker, it kind of defeats the purpose of - # this feature (which is to save a worker process if possible). 
- # elif task.start_trigger_args is not None: - # if task.expand_start_from_trigger(context=ti.get_template_context()): - # ti.start_date = timezone.utcnow() - # if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE: - # ti.try_number += 1 - # ti.defer_task(exception=None, session=session) - # else: - # schedulable_ti_ids.append(ti.id) - else: - empty_ti_ids.append(ti.id) + elif not ti.defer_task(session=session): + schedulable_ti_ids.append(ti.id) count = 0 diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index 0fb9e4e88d9e9..0e3a0a7c47365 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -115,7 +115,7 @@ from airflow.serialization.definitions.dag import SerializedDAG from airflow.serialization.definitions.mappedoperator import Operator from airflow.serialization.definitions.taskgroup import SerializedTaskGroup - + from airflow.triggers.base import StartTriggerArgs PAST_DEPENDS_MET = "past_depends_met" @@ -423,6 +423,7 @@ class TaskInstance(Base, LoggingMixin): # The trigger to resume on if we are in state DEFERRED trigger_id: Mapped[int | None] = mapped_column(Integer, nullable=True) + next_trigger_id: Mapped[int | None] = mapped_column(Integer, nullable=True) # Optional timeout utcdatetime for the trigger (past this, we'll fail) trigger_timeout: Mapped[datetime | None] = mapped_column(UtcDateTime, nullable=True) @@ -1429,6 +1430,109 @@ def update_heartbeat(self): .values(last_heartbeat_at=timezone.utcnow()) ) + @property + def start_trigger_args(self) -> StartTriggerArgs | None: + if self.task: + if self.task.is_mapped: + context = self.get_template_context() + if self.task.expand_start_from_trigger(context=context): + return self.task.expand_start_trigger_args(context=context) + elif self.task.start_from_trigger is True: + return self.task.start_trigger_args + return None + + # TODO: We have some code duplication here and in the _create_ti_state_update_query_and_update_state + # method of the task_instances module in the execution api when a TIDeferredStatePayload is being + # processed. This is because of a TaskInstance being updated differently using SQLAlchemy. + # If we use the approach from the execution api as common code in the DagRun schedule_tis method, + # the side effect is the changes done to the task instance aren't picked up by the scheduler and + # thus the task instance isn't processed until the scheduler is restarted. + @provide_session + def defer_task(self, session: Session = NEW_SESSION) -> bool: + """ + Mark the task as deferred and sets up the trigger that is needed to resume it when TaskDeferred is raised. 
+ + :meta: private + """ + from airflow.models.trigger import Trigger + + if TYPE_CHECKING: + assert isinstance(self.task, Operator) + + # Remaining task expansion still running from previous triggerer so reschedule + if self.context_carrier and "trigger" in self.context_carrier: + trigger_classpath, trigger_kwargs = self.context_carrier.pop("trigger", (None, None)) + + self.log.info( + "Creating trigger from context_carrier for task_id %s: %s", self.task_id, trigger_kwargs + ) + trigger_row = Trigger( + classpath=trigger_classpath, + kwargs=trigger_kwargs or {}, + ) + elif not (not self.next_trigger_id and (start_trigger_args := self.start_trigger_args)): + # self.log.warning("Couldn't create trigger from start_from_trigger for task_id %s thus could not be deferred!", self.task_id) + return False + else: + trigger_kwargs = start_trigger_args.trigger_kwargs or {} + timeout = start_trigger_args.timeout + self.next_method = start_trigger_args.next_method + self.next_kwargs = start_trigger_args.next_kwargs or {} + + # Calculate timeout too if it was passed + if timeout is not None: + self.trigger_timeout = timezone.utcnow() + timeout + else: + self.trigger_timeout = None + + self.log.info( + "Creating trigger from start_trigger_args for task_id %s: %s", self.task_id, trigger_kwargs + ) + trigger_row = Trigger( + classpath=start_trigger_args.trigger_cls, + kwargs=trigger_kwargs, + ) + + # First, make the trigger entry + session.add(trigger_row) + session.flush() + + # Then, update ourselves so it matches the deferral request + # Keep an eye on the logic in `check_and_change_state_before_execution()` + # depending on self.next_method semantics + self.state = TaskInstanceState.DEFERRED + self.trigger_id = trigger_row.id + + # If an execution_timeout is set, set the timeout to the minimum of + # it and the trigger timeout + if execution_timeout := self.task.execution_timeout: + if TYPE_CHECKING: + assert self.start_date + if self.trigger_timeout: + self.trigger_timeout = min(self.start_date + execution_timeout, self.trigger_timeout) + else: + self.trigger_timeout = self.start_date + execution_timeout + self.start_date = timezone.utcnow() + if self.state != TaskInstanceState.UP_FOR_RESCHEDULE: + self.try_number += 1 + if self.test_mode: + _add_log(event=self.state, task_instance=self, session=session) + return True + + @classmethod + def get_current_max_mapping(cls, dag_id: str, task_id: str, run_id: str, session: Session) -> int: + return max( + session.scalar( + select(func.max(TaskInstance.map_index)).where( + TaskInstance.dag_id == dag_id, + TaskInstance.task_id == task_id, + TaskInstance.run_id == run_id, + ) + ) + or 0, + 0, + ) + @classmethod def fetch_handle_failure_context( cls, @@ -1466,7 +1570,11 @@ def fetch_handle_failure_context( if not test_mode: session.add(Log(TaskInstanceState.FAILED.value, ti)) - ti.clear_next_method_args() + # Only clear next method args if first invocation on triggerer failed + if ( + not ti.next_trigger_id + ): # TODO: this check is very important, otherwise failed triggers will clear the XCom's + ti.clear_next_method_args() # Set state correctly and figure out how to log it and decide whether # to email diff --git a/airflow-core/src/airflow/models/taskmap.py b/airflow-core/src/airflow/models/taskmap.py index 0106b1a147054..213d6a18252b6 100644 --- a/airflow-core/src/airflow/models/taskmap.py +++ b/airflow-core/src/airflow/models/taskmap.py @@ -24,7 +24,7 @@ from collections.abc import Collection, Iterable, Sequence from typing import TYPE_CHECKING, Any -from 
sqlalchemy import CheckConstraint, ForeignKeyConstraint, Integer, String, func, or_, select +from sqlalchemy import CheckConstraint, ForeignKeyConstraint, Integer, String, or_, select from sqlalchemy.orm import Mapped from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies @@ -122,6 +122,16 @@ def variant(self) -> TaskMapVariant: return TaskMapVariant.LIST return TaskMapVariant.DICT + @classmethod + def get_task_map_length(cls, dag_id: str, task_id: str, run_id: str, session: Session) -> int | None: + return session.scalar( + select(TaskMap.length).where( + TaskMap.dag_id == dag_id, + TaskMap.task_id == task_id, + TaskMap.run_id == run_id, + ) + ) + @classmethod def expand_mapped_task( cls, @@ -152,7 +162,13 @@ def expand_mapped_task( ) try: - total_length: int | None = get_mapped_ti_count(task, run_id, session=session) + total_length: int | None = TaskMap.get_task_map_length( + dag_id=task.dag_id, task_id=task.task_id, run_id=run_id, session=session + ) + if not total_length: + total_length = get_mapped_ti_count(task, run_id, session=session) + else: + task = next((op for op in task.get_direct_relatives(upstream=False) if op.is_mapped), task) except NotFullyPopulated as e: if not task.dag or not task.dag.partial: task.log.error( @@ -163,17 +179,20 @@ def expand_mapped_task( ) total_length = None - state: str | None = None - unmapped_ti: TaskInstance | None = session.scalars( - select(TaskInstance).where( - TaskInstance.dag_id == task.dag_id, - TaskInstance.task_id == task.task_id, - TaskInstance.run_id == run_id, - TaskInstance.map_index == -1, - or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)), - ) - ).one_or_none() - + state: TaskInstanceState | None = None + unmapped_ti: TaskInstance | None = ( + session.scalars( + select(TaskInstance).where( + TaskInstance.dag_id == task.dag_id, + TaskInstance.task_id == task.task_id, + TaskInstance.run_id == run_id, + TaskInstance.map_index == -1, + or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)), + ) + ).one_or_none() + if task and task.is_mapped + else None + ) all_expanded_tis: list[TaskInstance] = [] if unmapped_ti: @@ -223,15 +242,8 @@ def expand_mapped_task( indexes_to_map: Iterable[int] = () else: # Only create "missing" ones. 
- current_max_mapping = ( - session.scalar( - select(func.max(TaskInstance.map_index)).where( - TaskInstance.dag_id == task.dag_id, - TaskInstance.task_id == task.task_id, - TaskInstance.run_id == run_id, - ) - ) - or 0 + current_max_mapping = TaskInstance.get_current_max_mapping( + dag_id=task.dag_id, task_id=task.task_id, run_id=run_id, session=session ) indexes_to_map = range(current_max_mapping + 1, total_length) diff --git a/airflow-core/src/airflow/models/trigger.py b/airflow-core/src/airflow/models/trigger.py index 942b811ce6509..db0a41d3fc01e 100644 --- a/airflow-core/src/airflow/models/trigger.py +++ b/airflow-core/src/airflow/models/trigger.py @@ -24,7 +24,7 @@ from traceback import format_exception from typing import TYPE_CHECKING, Any -from sqlalchemy import Integer, String, Text, delete, func, or_, select, update +from sqlalchemy import Integer, String, Text, delete, exists, func, or_, select, update from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.orm import Mapped, Session, relationship, selectinload from sqlalchemy.sql.functions import coalesce @@ -40,7 +40,7 @@ from airflow.utils.retries import run_with_db_retries from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.sqlalchemy import UtcDateTime, get_dialect_name, mapped_column, with_row_locks -from airflow.utils.state import TaskInstanceState +from airflow.utils.state import State, TaskInstanceState if TYPE_CHECKING: from sqlalchemy import Row @@ -218,13 +218,26 @@ def fetch_trigger_ids_with_non_task_associations(cls, session: Session = NEW_SES @classmethod @provide_session - def clean_unused(cls, session: Session = NEW_SESSION) -> None: + def clean_unused(cls, finished_triggers: set | None = None, session: Session = NEW_SESSION) -> None: """ Delete all triggers that have no tasks dependent on them and are not associated to an asset. Triggers have a one-to-many relationship to task instances, so we need to clean those up first. Afterward we can drop the triggers not referenced by anyone. 
""" + # TODO: should be moved into dedicated method + if finished_triggers: + session.execute( + update(TaskInstance) + .where( + or_( + TaskInstance.trigger_id.in_(finished_triggers), + TaskInstance.next_trigger_id.in_(finished_triggers), + ) + ) + .values(trigger_id=None, next_trigger_id=None) + ) + # Update all task instances with trigger IDs that are not DEFERRED to remove them for attempt in run_with_db_retries(): with attempt: @@ -237,12 +250,15 @@ def clean_unused(cls, session: Session = NEW_SESSION) -> None: ) # Get all triggers that have no task instances, assets, or callbacks depending on them and delete them - ids = ( - select(cls.id) - .where(~cls.assets.any(), ~cls.callback.has()) - .join(TaskInstance, cls.id == TaskInstance.trigger_id, isouter=True) - .group_by(cls.id) - .having(func.count(TaskInstance.trigger_id) == 0) + ids = select(Trigger.id).where( + # no TIs referencing trigger_id that are not failed + ~exists().where(TaskInstance.trigger_id == Trigger.id), + # no TIs referencing next_trigger_id that are not failed + ~exists().where(TaskInstance.next_trigger_id == Trigger.id), + # no assets + ~cls.assets.any(), + # no callback + ~cls.callback.has(), ) if get_dialect_name(session) == "mysql": # MySQL doesn't support DELETE with JOIN, so we need to do it in two steps @@ -257,7 +273,9 @@ def clean_unused(cls, session: Session = NEW_SESSION) -> None: @classmethod @provide_session - def submit_event(cls, trigger_id, event: TriggerEvent, session: Session = NEW_SESSION) -> None: + def submit_event( + cls, trigger_id, event: TriggerEvent, is_last_event: bool = True, session: Session = NEW_SESSION + ) -> bool: """ Fire an event. @@ -265,30 +283,58 @@ def submit_event(cls, trigger_id, event: TriggerEvent, session: Session = NEW_SE Send an event to all assets associated to the trigger. 
""" # Resume deferred tasks - for task_instance in session.scalars( + task_instances = session.scalars( select(TaskInstance).where( - TaskInstance.trigger_id == trigger_id, TaskInstance.state == TaskInstanceState.DEFERRED + or_( + TaskInstance.trigger_id == trigger_id, + TaskInstance.next_trigger_id == trigger_id, + # We need to do this as once we run the next_method, trigger_id is removed from TaskInstance + ), + TaskInstance.state.in_([TaskInstanceState.DEFERRED, TaskInstanceState.SUCCESS]), + # TODO: SUCCESS might become COMPLETED ) - ): - handle_event_submit(event, task_instance=task_instance, session=session) + ).all() + + if task_instances: + log.info("Handle event for trigger %s", trigger_id) + + # Resume deferred tasks + for task_instance in task_instances: + handle_event_submit( + event, + trigger_id=trigger_id, + task_instance=task_instance, + is_last_event=is_last_event, + session=session, + ) - # Send an event to assets - trigger = session.scalars(select(cls).where(cls.id == trigger_id)).one_or_none() - if trigger is None: - # Already deleted for some reason - return - for asset in trigger.assets: - AssetManager.register_asset_change( - asset=asset.to_serialized(), - extra={"from_trigger": True, "payload": event.payload}, - session=session, - ) - if trigger.callback: - trigger.callback.handle_event(event, session) + # Send an event to assets + trigger = session.scalars( + select(Trigger).options(selectinload(Trigger.assets)).where(Trigger.id == trigger_id) + ).one_or_none() + if not trigger: + # Already deleted for some reason + return False + for asset in trigger.assets: + AssetManager.register_asset_change( + asset=asset.to_public(), + extra={"from_trigger": True, "payload": event.payload}, + session=session, + ) + if trigger.callback: + trigger.callback.handle_event(event, session) + return True + + log.debug( + "No more task instances found for trigger %s! Stop processing events for trigger %s", + trigger_id, + trigger_id, + ) + return False @classmethod @provide_session - def submit_failure(cls, trigger_id, exc=None, session: Session = NEW_SESSION) -> None: + def submit_failure(cls, trigger_id, trigger: dict, exc=None, session: Session = NEW_SESSION) -> bool: """ When a trigger has failed unexpectedly, mark everything that depended on it as failed. @@ -300,6 +346,41 @@ def submit_failure(cls, trigger_id, exc=None, session: Session = NEW_SESSION) -> the runtime code understands as immediate-fail, and pack the error into next_kwargs. 
""" + if trigger: + unfinished_tis = session.scalar( + select(func.count()) + .select_from(TaskInstance) + .where( + TaskInstance.next_trigger_id == trigger_id, + ~TaskInstance.state.in_(State.finished_dr_states), + ) + .execution_options(populate_existing=True) + ) + + log.debug("unfinished_tis: %d", unfinished_tis) + + if unfinished_tis == 0: + task_instances = list( + session.scalars( + select(TaskInstance).where( + TaskInstance.next_trigger_id == trigger_id, + # TaskInstance.state.in_([TaskInstanceState.RUNNING, TaskInstanceState.SUCCESS]), + ) + ) + ) + + log.debug("task_instances: %d", len(task_instances)) + + for task_instance in task_instances: + task_instance.next_trigger_id = None + task_instance.context_carrier = { + **(task_instance.context_carrier or {}), + **{"trigger": trigger}, + } + task_instance.set_state(TaskInstanceState.UP_FOR_RETRY) + return True + return False + for task_instance in session.scalars( select(TaskInstance).where( TaskInstance.trigger_id == trigger_id, TaskInstance.state == TaskInstanceState.DEFERRED @@ -321,6 +402,8 @@ def submit_failure(cls, trigger_id, exc=None, session: Session = NEW_SESSION) -> task_instance.state = TaskInstanceState.SCHEDULED task_instance.scheduled_dttm = timezone.utcnow() + return False + @classmethod @provide_session def ids_for_triggerer( @@ -452,7 +535,14 @@ def get_sorted_triggers( @singledispatch -def handle_event_submit(event: TriggerEvent, *, task_instance: TaskInstance, session: Session) -> None: +def handle_event_submit( + event: TriggerEvent, + *, + trigger_id: int, + task_instance: TaskInstance, + is_last_event: bool = True, + session: Session, +) -> None: """ Handle the submit event for a given task instance. @@ -492,11 +582,25 @@ def handle_event_submit(event: TriggerEvent, *, task_instance: TaskInstance, ses # Set the state of the task instance to scheduled task_instance.state = TaskInstanceState.SCHEDULED task_instance.scheduled_dttm = timezone.utcnow() + + if is_last_event: + task_instance.next_trigger_id = None + else: + log.info("trigger %s is not last event to be processed...", trigger_id) + task_instance.try_number = 0 + task_instance.next_trigger_id = trigger_id + session.flush() @handle_event_submit.register -def _(event: BaseTaskEndEvent, *, task_instance: TaskInstance, session: Session) -> None: +def _( + event: BaseTaskEndEvent, + *, + task_instance: TaskInstance, + session: Session, + **_: Any, +) -> None: """ Submit event for the given task instance. diff --git a/airflow-core/src/airflow/utils/db.py b/airflow-core/src/airflow/utils/db.py index d444689e26b56..ca7357844db12 100644 --- a/airflow-core/src/airflow/utils/db.py +++ b/airflow-core/src/airflow/utils/db.py @@ -112,7 +112,7 @@ class MappedClassProtocol(Protocol): "3.0.0": "29ce7909c52b", "3.0.3": "fe199e1abd77", "3.1.0": "cc92b33c6709", - "3.2.0": "e79fc784f145", + "3.2.0": "658517c60c7f", } # Prefix used to identify tables holding data moved during migration. 
diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index df31c9225c272..7833517ce6cbe 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -307,7 +307,13 @@ class TestTriggerRunner: def test_run_inline_trigger_canceled(self, session) -> None: trigger_runner = TriggerRunner() trigger_runner.triggers = { - 1: {"task": MagicMock(spec=asyncio.Task), "is_watcher": False, "name": "mock_name", "events": 0} + 1: { + "task": MagicMock(spec=asyncio.Task), + "is_watcher": False, + "name": "mock_name", + "events": 0, + "trigger": None, + } } mock_trigger = MagicMock(spec=BaseTrigger) mock_trigger.timeout_after = None @@ -320,7 +326,13 @@ def test_run_inline_trigger_canceled(self, session) -> None: def test_run_inline_trigger_timeout(self, session, cap_structlog) -> None: trigger_runner = TriggerRunner() trigger_runner.triggers = { - 1: {"task": MagicMock(spec=asyncio.Task), "is_watcher": False, "name": "mock_name", "events": 0} + 1: { + "task": MagicMock(spec=asyncio.Task), + "is_watcher": False, + "name": "mock_name", + "events": 0, + "trigger": None, + } } mock_trigger = MagicMock(spec=BaseTrigger) mock_trigger.run.side_effect = asyncio.CancelledError() diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py index 5636d93d4a317..81e177e24f25e 100644 --- a/airflow-core/tests/unit/models/test_taskinstance.py +++ b/airflow-core/tests/unit/models/test_taskinstance.py @@ -2416,6 +2416,7 @@ def test_refresh_from_db(self, create_task_instance): "trigger_id": None, "next_kwargs": None, "next_method": None, + "next_trigger_id": None, "updated_at": None, "task_display_name": "Test Refresh from DB Task", "dag_version_id": mock.ANY,
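
A pytest-style sketch of the PartitionedQueue round-trip used between the async producer side (put per trigger id) and the non-blocking consumer side (popleft), assuming the implementation added in airflow/jobs/queues.py; the ids and payloads are illustrative:

    import asyncio

    import pytest

    from airflow.jobs.queues import PartitionedQueue


    def test_partitioned_queue_roundtrip():
        queue: PartitionedQueue[int, str] = PartitionedQueue(maxsize=2)

        async def fill():
            # put() is async because each partition is backed by an asyncio.Queue
            await queue.put((1, "event-a"))
            await queue.put((2, "event-b"))

        asyncio.run(fill())

        assert bool(queue)  # O(1) truthiness backed by the total counter
        assert queue.popleft() == (1, "event-a")
        assert queue.popleft() == (2, "event-b")
        assert not queue

        with pytest.raises(StopIteration):
            queue.popleft()  # nothing left in any partition
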