Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions airflow-core/src/airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -2063,25 +2063,30 @@ def schedule_tis(
empty_ti_ids.append(ti.id)

count = 0

# Don't only check if the TI.id is in id_chunk
# but also check if the TI.state is in the schedulable states.
# Plus, a scheduled empty operator should not be scheduled again.
non_null_schedulable_states = tuple(s for s in SCHEDULEABLE_STATES if s is not None)
schedulable_state_clause = or_(
TI.state.is_(None),
TI.state.in_(non_null_schedulable_states),
)
Comment on lines +2070 to +2073
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should do some benchmarking/checking of indexes on this query if you haven't already. TI state is almost certainly indexed, but we should check this with some EXPLAIN ANALYZE in a moderately sized DB if we have one (something in the region of 1-2m TI rows?)

Might also be worth checking if a schedulable_state_clause = TI.state != TaskInstanceState.SCHEDULED is enough and more-performant?

next_try_number = case(
(TI.state == TaskInstanceState.UP_FOR_RESCHEDULE, TI.try_number),
else_=TI.try_number + 1,
)
if schedulable_ti_ids:
schedulable_ti_ids_chunks = chunks(
schedulable_ti_ids, max_tis_per_query or len(schedulable_ti_ids)
)
for id_chunk in schedulable_ti_ids_chunks:
result = session.execute(
update(TI)
.where(TI.id.in_(id_chunk))
.where(TI.id.in_(id_chunk), schedulable_state_clause)
.values(
state=TaskInstanceState.SCHEDULED,
scheduled_dttm=timezone.utcnow(),
try_number=case(
(
or_(TI.state.is_(None), TI.state != TaskInstanceState.UP_FOR_RESCHEDULE),
TI.try_number + 1,
),
else_=TI.try_number,
),
try_number=next_try_number,
)
.execution_options(synchronize_session=False)
)
Expand All @@ -2093,13 +2098,13 @@ def schedule_tis(
for id_chunk in dummy_ti_ids_chunks:
result = session.execute(
update(TI)
.where(TI.id.in_(id_chunk))
.where(TI.id.in_(id_chunk), schedulable_state_clause)
.values(
state=TaskInstanceState.SUCCESS,
start_date=timezone.utcnow(),
end_date=timezone.utcnow(),
duration=0,
try_number=TI.try_number + 1,
try_number=next_try_number,
)
.execution_options(
synchronize_session=False,
Expand Down
253 changes: 252 additions & 1 deletion airflow-core/tests/unit/models/test_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@

import pendulum
import pytest
from sqlalchemy import exists, func, select
from sqlalchemy import (
exists,
func,
select,
update,
)
from sqlalchemy.orm import joinedload

from airflow import settings
Expand All @@ -51,6 +56,7 @@
from airflow.serialization.serialized_objects import LazyDeserializedDAG
from airflow.task.trigger_rule import TriggerRule
from airflow.triggers.base import StartTriggerArgs
from airflow.utils.session import create_session
from airflow.utils.span_status import SpanStatus
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.thread_safe_dict import ThreadSafeDict
Expand Down Expand Up @@ -2046,6 +2052,251 @@ def test_schedule_tis_map_index(dag_maker, session):
assert ti2.state == TaskInstanceState.SUCCESS


def test_schedule_tis_does_not_increment_try_number_if_ti_already_queued_by_other_scheduler(
    dag_maker, session
):
    """A stale scheduler must not touch a TI that another scheduler already queued.

    Reproduces the HA race: scheduler A loads the TI into memory, scheduler B
    queues it directly in the DB, then scheduler A calls ``schedule_tis`` with
    its stale TI object. The state-guarded UPDATE must match zero rows, leaving
    state and try_number untouched.
    """
    with dag_maker(session=session) as dag:
        BashOperator(task_id="task", bash_command="echo 1")

    dr = dag_maker.create_dagrun(session=session)
    ti = dr.get_task_instance("task", session=session)
    assert ti is not None
    ti.refresh_from_task(dag.get_task("task"))
    assert ti.state is None

    # The stale scheduler picks try 1. commit() flushes implicitly, so no
    # separate flush() call is needed.
    ti.try_number = 1
    session.commit()

    # Another scheduler already queued the TI in the DB (same try), via an
    # independent session emulating a second scheduler instance.
    with create_session() as other_session:
        filter_for_tis = TI.filter_for_tis([ti])
        assert filter_for_tis is not None
        other_session.execute(
            update(TI)
            .where(filter_for_tis)
            .values(
                state=TaskInstanceState.QUEUED,
                try_number=1,
            )
            .execution_options(synchronize_session=False)
        )

    # This stale scheduler still has a stale TI object; schedule_tis must be a no-op.
    assert dr.schedule_tis((ti,), session=session) == 0

    refreshed_ti = session.scalar(
        select(TI).where(
            TI.dag_id == ti.dag_id,
            TI.task_id == ti.task_id,
            TI.run_id == ti.run_id,
            TI.map_index == ti.map_index,
        )
    )
    # The concurrent scheduler's write must win, untouched.
    assert refreshed_ti.state == TaskInstanceState.QUEUED
    assert refreshed_ti.try_number == 1


def test_schedule_tis_empty_operator_does_not_short_circuit_if_ti_already_queued(dag_maker, session):
    """An EmptyOperator TI already queued by another scheduler must not be short-circuited.

    ``schedule_tis`` normally marks empty operators SUCCESS directly; when a
    concurrent scheduler has already queued the TI in the DB, the state-guarded
    UPDATE must match zero rows instead of overriding that state.
    """
    with dag_maker(session=session) as dag:
        EmptyOperator(task_id="empty_task")

    dr = dag_maker.create_dagrun(session=session)
    ti = dr.get_task_instance("empty_task", session=session)
    ti.refresh_from_task(dag.get_task("empty_task"))
    assert ti.state is None

    # Stale scheduler picks the TI. commit() flushes implicitly, so no
    # separate flush() call is needed.
    ti.try_number = 1
    session.commit()

    # Another scheduler already queued it via an independent session.
    with create_session() as other_session:
        filter_for_tis = TI.filter_for_tis([ti])
        assert filter_for_tis is not None
        other_session.execute(
            update(TI)
            .where(filter_for_tis)
            .values(
                state=TaskInstanceState.QUEUED,
                try_number=1,
            )
            .execution_options(synchronize_session=False)
        )

    # No short-circuit to SUCCESS: the UPDATE must match zero rows.
    assert dr.schedule_tis((ti,), session=session) == 0

    refreshed_ti = session.scalar(
        select(TI).where(
            TI.dag_id == ti.dag_id,
            TI.task_id == ti.task_id,
            TI.run_id == ti.run_id,
            TI.map_index == ti.map_index,
        )
    )
    assert refreshed_ti is not None
    assert refreshed_ti.state == TaskInstanceState.QUEUED
    assert refreshed_ti.try_number == 1


def test_schedule_tis_up_for_reschedule_does_not_increment_try_number(dag_maker, session):
    """Scheduling a TI that is UP_FOR_RESCHEDULE must leave its try_number unchanged."""
    with dag_maker(session=session) as dag:
        BashOperator(task_id="task", bash_command="echo 1")

    dagrun = dag_maker.create_dagrun(session=session)
    task_instance = dagrun.get_task_instance("task", session=session)
    task_instance.refresh_from_task(dag.get_task("task"))

    # A rescheduled sensor re-runs on the same try; pin try_number to verify.
    task_instance.state = TaskInstanceState.UP_FOR_RESCHEDULE
    task_instance.try_number = 3
    session.commit()

    scheduled_count = dagrun.schedule_tis((task_instance,), session=session)
    assert scheduled_count == 1
    session.commit()

    # schedule_tis runs with synchronize_session=False, so the identity map may
    # still hold a stale instance; expire everything so the SELECT reflects the
    # actual DB row.
    session.expire_all()
    row = session.scalar(
        select(TI).where(
            TI.dag_id == task_instance.dag_id,
            TI.task_id == task_instance.task_id,
            TI.run_id == task_instance.run_id,
            TI.map_index == task_instance.map_index,
        )
    )
    assert row.state == TaskInstanceState.SCHEDULED
    assert row.try_number == 3


def test_schedule_tis_is_noop_if_ti_transitions_to_nonschedulable_state_before_update(dag_maker, session):
    """schedule_tis must be a no-op if the TI left a schedulable state before the UPDATE ran.

    The function-local ``from airflow.utils.session import create_session`` the
    original version carried was redundant — ``create_session`` is already
    imported at module level.
    """
    with dag_maker(session=session) as dag:
        BashOperator(task_id="task", bash_command="echo 1")

    dr = dag_maker.create_dagrun(session=session)
    ti = dr.get_task_instance("task", session=session)
    ti.refresh_from_task(dag.get_task("task"))

    # First scheduler's view.
    ti.try_number = 1
    session.commit()

    # Another scheduler already queued it via an independent session.
    with create_session() as other_session:
        filter_for_tis = TI.filter_for_tis([ti])
        assert filter_for_tis is not None
        other_session.execute(
            update(TI)
            .where(filter_for_tis)
            .values(
                state=TaskInstanceState.QUEUED,
                try_number=1,
            )
            .execution_options(synchronize_session=False)
        )

    # Attempting a schedule from the first scheduler (different session)
    # results in 0 scheduled TIs; create_session commits on context exit, so
    # the QUEUED state is visible by the time the UPDATE runs.
    assert dr.schedule_tis((ti,), session=session) == 0

    refreshed_ti = session.scalar(
        select(TI).where(
            TI.dag_id == ti.dag_id,
            TI.task_id == ti.task_id,
            TI.run_id == ti.run_id,
            TI.map_index == ti.map_index,
        )
    )
    assert refreshed_ti.state == TaskInstanceState.QUEUED
    assert refreshed_ti.try_number == 1


def test_schedule_tis_empty_operator_is_noop_if_ti_already_running(dag_maker, session):
    """An EmptyOperator TI already RUNNING elsewhere must not be forced to SUCCESS.

    RUNNING is not a schedulable state, so the state-guarded UPDATE in
    ``schedule_tis`` must match zero rows. The redundant function-local
    ``create_session`` import was removed — it is already imported at module
    level.
    """
    with dag_maker(session=session) as dag:
        EmptyOperator(task_id="empty_task")

    dr = dag_maker.create_dagrun(session=session)
    ti = dr.get_task_instance("empty_task", session=session)
    ti.refresh_from_task(dag.get_task("empty_task"))

    ti.try_number = 3
    session.commit()

    # Another scheduler/worker marked the TI RUNNING via an independent session.
    with create_session() as other_session:
        filter_for_tis = TI.filter_for_tis([ti])
        assert filter_for_tis is not None
        other_session.execute(
            update(TI)
            .where(filter_for_tis)
            .values(
                state=TaskInstanceState.RUNNING,
                try_number=3,
            )
            .execution_options(synchronize_session=False)
        )

    assert dr.schedule_tis((ti,), session=session) == 0

    refreshed_ti = session.scalar(
        select(TI).where(
            TI.dag_id == ti.dag_id,
            TI.task_id == ti.task_id,
            TI.run_id == ti.run_id,
            TI.map_index == ti.map_index,
        )
    )
    assert refreshed_ti is not None
    assert refreshed_ti.state == TaskInstanceState.RUNNING
    assert refreshed_ti.try_number == 3


def test_schedule_tis_only_one_scheduler_update_succeeds_when_competing(dag_maker, session):
    """Of two competing schedulers, exactly one schedule_tis UPDATE may succeed.

    Scheduler A schedules first; scheduler B, operating on a fresh session but
    a now-stale view, must get a zero-row update. The redundant function-local
    ``create_session`` import was removed — it is already imported at module
    level.
    """
    with dag_maker(session=session) as dag:
        BashOperator(task_id="task", bash_command="echo 1")

    dr = dag_maker.create_dagrun(session=session)
    ti = dr.get_task_instance("task", session=session)
    ti.refresh_from_task(dag.get_task("task"))
    assert ti.state is None

    ti.try_number = 0
    session.commit()

    # Scheduler A schedules first.
    assert dr.schedule_tis((ti,), session=session) == 1
    session.commit()

    # Scheduler B (stale view) tries again in a new session; must be a no-op.
    with create_session() as scheduler_b_session:
        ti_b = scheduler_b_session.scalar(
            select(TI).where(
                TI.dag_id == ti.dag_id,
                TI.task_id == ti.task_id,
                TI.run_id == ti.run_id,
                TI.map_index == ti.map_index,
            )
        )
        assert ti_b is not None
        assert dr.schedule_tis((ti_b,), session=scheduler_b_session) == 0

    refreshed_ti = session.scalar(
        select(TI).where(
            TI.dag_id == ti.dag_id,
            TI.task_id == ti.task_id,
            TI.run_id == ti.run_id,
            TI.map_index == ti.map_index,
        )
    )
    # Scheduler A's write stands: SCHEDULED, with try_number bumped 0 -> 1.
    assert refreshed_ti.state == TaskInstanceState.SCHEDULED
    assert refreshed_ti.try_number == 1


@pytest.mark.xfail(reason="We can't keep this behaviour with remote workers where scheduler can't reach xcom")
@pytest.mark.need_serialized_dag
def test_schedule_tis_start_trigger(dag_maker, session):
Expand Down
Loading