From 3ad9273e7bddd863c933fc56f7212ff300691b78 Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Fri, 9 Jan 2026 14:47:29 +0100 Subject: [PATCH] Fix HA scheduler try_number double increment MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In HA, two scheduler processes can race to schedule the same TaskInstance. Previously DagRun.schedule_tis() updated rows by ti.id alone, so a scheduler could increment try_number and transition state even after another scheduler had already advanced the TI (e.g. to SCHEDULED/QUEUED), resulting in duplicate attempts being queued. This change makes scheduling idempotent under HA races by: - Guarding schedule_tis() DB updates to only apply when the TI is still in schedulable states (derived from SCHEDULEABLE_STATES, handling NULL explicitly). - Using a single CASE (next_try_number) so reschedules (UP_FOR_RESCHEDULE) do not start a new try, and applying this consistently to both normal scheduling and the EmptyOperator fast-path. Adds regression tests covering: - TI already queued by another scheduler. - EmptyOperator fast-path blocked when TI is already QUEUED/RUNNING. - UP_FOR_RESCHEDULE scheduling keeps try_number unchanged. - Only one “scheduler” update succeeds when competing. --- airflow-core/src/airflow/models/dagrun.py | 27 +- airflow-core/tests/unit/models/test_dagrun.py | 253 +++++++++++++++++- 2 files changed, 268 insertions(+), 12 deletions(-) diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index bd3973b77683d..052ddc1a45947 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -2063,7 +2063,18 @@ def schedule_tis( empty_ti_ids.append(ti.id) count = 0 - + # Don't only check if the TI.id is in id_chunk + # but also check if the TI.state is in the schedulable states. + # Plus, a scheduled empty operator should not be scheduled again. 
+ non_null_schedulable_states = tuple(s for s in SCHEDULEABLE_STATES if s is not None) + schedulable_state_clause = or_( + TI.state.is_(None), + TI.state.in_(non_null_schedulable_states), + ) + next_try_number = case( + (TI.state == TaskInstanceState.UP_FOR_RESCHEDULE, TI.try_number), + else_=TI.try_number + 1, + ) if schedulable_ti_ids: schedulable_ti_ids_chunks = chunks( schedulable_ti_ids, max_tis_per_query or len(schedulable_ti_ids) @@ -2071,17 +2082,11 @@ def schedule_tis( for id_chunk in schedulable_ti_ids_chunks: result = session.execute( update(TI) - .where(TI.id.in_(id_chunk)) + .where(TI.id.in_(id_chunk), schedulable_state_clause) .values( state=TaskInstanceState.SCHEDULED, scheduled_dttm=timezone.utcnow(), - try_number=case( - ( - or_(TI.state.is_(None), TI.state != TaskInstanceState.UP_FOR_RESCHEDULE), - TI.try_number + 1, - ), - else_=TI.try_number, - ), + try_number=next_try_number, ) .execution_options(synchronize_session=False) ) @@ -2093,13 +2098,13 @@ def schedule_tis( for id_chunk in dummy_ti_ids_chunks: result = session.execute( update(TI) - .where(TI.id.in_(id_chunk)) + .where(TI.id.in_(id_chunk), schedulable_state_clause) .values( state=TaskInstanceState.SUCCESS, start_date=timezone.utcnow(), end_date=timezone.utcnow(), duration=0, - try_number=TI.try_number + 1, + try_number=next_try_number, ) .execution_options( synchronize_session=False, diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py index 636012ff8c044..f805f40486181 100644 --- a/airflow-core/tests/unit/models/test_dagrun.py +++ b/airflow-core/tests/unit/models/test_dagrun.py @@ -27,7 +27,12 @@ import pendulum import pytest -from sqlalchemy import exists, func, select +from sqlalchemy import ( + exists, + func, + select, + update, +) from sqlalchemy.orm import joinedload from airflow import settings @@ -51,6 +56,7 @@ from airflow.serialization.serialized_objects import LazyDeserializedDAG from airflow.task.trigger_rule import 
TriggerRule from airflow.triggers.base import StartTriggerArgs +from airflow.utils.session import create_session from airflow.utils.span_status import SpanStatus from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.thread_safe_dict import ThreadSafeDict @@ -2046,6 +2052,251 @@ def test_schedule_tis_map_index(dag_maker, session): assert ti2.state == TaskInstanceState.SUCCESS +def test_schedule_tis_does_not_increment_try_number_if_ti_already_queued_by_other_scheduler( + dag_maker, session +): + with dag_maker(session=session) as dag: + BashOperator(task_id="task", bash_command="echo 1") + + dr = dag_maker.create_dagrun(session=session) + ti = dr.get_task_instance("task", session=session) + assert ti is not None + ti.refresh_from_task(dag.get_task("task")) + assert ti.state is None + + # The stale scheduler picks try 1. + ti.try_number = 1 + session.flush() + session.commit() + + # Another scheduler already queued the TI in DB (same try). + with create_session() as other_session: + filter_for_tis = TI.filter_for_tis([ti]) + assert filter_for_tis is not None + other_session.execute( + update(TI) + .where(filter_for_tis) + .values( + state=TaskInstanceState.QUEUED, + try_number=1, + ) + .execution_options(synchronize_session=False) + ) + + # This stale scheduler still has a stale TI object; schedule_tis must be a no-op. 
+ assert dr.schedule_tis((ti,), session=session) == 0 + + refreshed_ti = session.scalar( + select(TI).where( + TI.dag_id == ti.dag_id, + TI.task_id == ti.task_id, + TI.run_id == ti.run_id, + TI.map_index == ti.map_index, + ) + ) + assert refreshed_ti.state == TaskInstanceState.QUEUED + assert refreshed_ti.try_number == 1 + + +def test_schedule_tis_empty_operator_does_not_short_circuit_if_ti_already_queued(dag_maker, session): + with dag_maker(session=session) as dag: + EmptyOperator(task_id="empty_task") + + dr = dag_maker.create_dagrun(session=session) + ti = dr.get_task_instance("empty_task", session=session) + ti.refresh_from_task(dag.get_task("empty_task")) + assert ti.state is None + + # The stale scheduler picks up the TI at try 1. + ti.try_number = 1 + session.flush() + session.commit() + + # Another scheduler already queued it. + with create_session() as other_session: + filter_for_tis = TI.filter_for_tis([ti]) + assert filter_for_tis is not None + other_session.execute( + update(TI) + .where(filter_for_tis) + .values( + state=TaskInstanceState.QUEUED, + try_number=1, + ) + .execution_options(synchronize_session=False) + ) + + # The EmptyOperator fast path must not short-circuit the already-queued TI to SUCCESS. + assert dr.schedule_tis((ti,), session=session) == 0 + + refreshed_ti = session.scalar( + select(TI).where( + TI.dag_id == ti.dag_id, + TI.task_id == ti.task_id, + TI.run_id == ti.run_id, + TI.map_index == ti.map_index, + ) + ) + assert refreshed_ti is not None + assert refreshed_ti.state == TaskInstanceState.QUEUED + assert refreshed_ti.try_number == 1 + + +def test_schedule_tis_up_for_reschedule_does_not_increment_try_number(dag_maker, session): + with dag_maker(session=session) as dag: + BashOperator(task_id="task", bash_command="echo 1") + + dr = dag_maker.create_dagrun(session=session) + ti = dr.get_task_instance("task", session=session) + ti.refresh_from_task(dag.get_task("task")) + + ti.state = TaskInstanceState.UP_FOR_RESCHEDULE + ti.try_number = 3 + session.commit() + + assert dr.schedule_tis((ti,), session=session) == 1 +
session.commit() + + # schedule_tis uses synchronize_session=False, so the session may still hold a stale instance. + # Expire the identity map so the SELECT reflects the DB row. + session.expire_all() + refreshed_ti = session.scalar( + select(TI).where( + TI.dag_id == ti.dag_id, + TI.task_id == ti.task_id, + TI.run_id == ti.run_id, + TI.map_index == ti.map_index, + ) + ) + assert refreshed_ti.state == TaskInstanceState.SCHEDULED + assert refreshed_ti.try_number == 3 + + +def test_schedule_tis_is_noop_if_ti_transitions_to_nonschedulable_state_before_update(dag_maker, session): + from airflow.utils.session import create_session + + with dag_maker(session=session) as dag: + BashOperator(task_id="task", bash_command="echo 1") + + dr = dag_maker.create_dagrun(session=session) + ti = dr.get_task_instance("task", session=session) + ti.refresh_from_task(dag.get_task("task")) + + # The first scheduler's view of the TI: try 1, not yet scheduled. + ti.try_number = 1 + session.commit() + + # Another scheduler already queued it. + with create_session() as other_session: + filter_for_tis = TI.filter_for_tis([ti]) + assert filter_for_tis is not None + other_session.execute( + update(TI) + .where(filter_for_tis) + .values( + state=TaskInstanceState.QUEUED, + try_number=1, + ) + .execution_options(synchronize_session=False) + ) + # Scheduling again from the first scheduler (a different session) must schedule zero TIs.
+ assert dr.schedule_tis((ti,), session=session) == 0 + + refreshed_ti = session.scalar( + select(TI).where( + TI.dag_id == ti.dag_id, + TI.task_id == ti.task_id, + TI.run_id == ti.run_id, + TI.map_index == ti.map_index, + ) + ) + assert refreshed_ti.state == TaskInstanceState.QUEUED + assert refreshed_ti.try_number == 1 + + +def test_schedule_tis_empty_operator_is_noop_if_ti_already_running(dag_maker, session): + from airflow.utils.session import create_session + + with dag_maker(session=session) as dag: + EmptyOperator(task_id="empty_task") + + dr = dag_maker.create_dagrun(session=session) + ti = dr.get_task_instance("empty_task", session=session) + ti.refresh_from_task(dag.get_task("empty_task")) + + ti.try_number = 3 + session.commit() + + with create_session() as other_session: + filter_for_tis = TI.filter_for_tis([ti]) + assert filter_for_tis is not None + other_session.execute( + update(TI) + .where(filter_for_tis) + .values( + state=TaskInstanceState.RUNNING, + try_number=3, + ) + .execution_options(synchronize_session=False) + ) + + assert dr.schedule_tis((ti,), session=session) == 0 + + refreshed_ti = session.scalar( + select(TI).where( + TI.dag_id == ti.dag_id, + TI.task_id == ti.task_id, + TI.run_id == ti.run_id, + TI.map_index == ti.map_index, + ) + ) + assert refreshed_ti.state == TaskInstanceState.RUNNING + assert refreshed_ti.try_number == 3 + + +def test_schedule_tis_only_one_scheduler_update_succeeds_when_competing(dag_maker, session): + from airflow.utils.session import create_session + + with dag_maker(session=session) as dag: + BashOperator(task_id="task", bash_command="echo 1") + + dr = dag_maker.create_dagrun(session=session) + ti = dr.get_task_instance("task", session=session) + ti.refresh_from_task(dag.get_task("task")) + assert ti.state is None + + ti.try_number = 0 + session.commit() + + # Scheduler A schedules first. 
+ assert dr.schedule_tis((ti,), session=session) == 1 + session.commit() + + # Scheduler B (stale view) tries again in a new session; should be a no-op. + with create_session() as scheduler_b_session: + ti_b = scheduler_b_session.scalar( + select(TI).where( + TI.dag_id == ti.dag_id, + TI.task_id == ti.task_id, + TI.run_id == ti.run_id, + TI.map_index == ti.map_index, + ) + ) + assert ti_b is not None + assert dr.schedule_tis((ti_b,), session=scheduler_b_session) == 0 + + refreshed_ti = session.scalar( + select(TI).where( + TI.dag_id == ti.dag_id, + TI.task_id == ti.task_id, + TI.run_id == ti.run_id, + TI.map_index == ti.map_index, + ) + ) + assert refreshed_ti.state == TaskInstanceState.SCHEDULED + assert refreshed_ti.try_number == 1 + + @pytest.mark.xfail(reason="We can't keep this behaviour with remote workers where scheduler can't reach xcom") @pytest.mark.need_serialized_dag def test_schedule_tis_start_trigger(dag_maker, session):